from __future__ import division

#import Startup
import ephere_moov as moov
import ephere_ornatrix as ox

import Math
import math
import random


def ClampedStiffness( stiffness, useCompliantConstraints ):
	"""Returns the stiffness value to store in a constraint description.

	Compliant (XPBD) constraints take the stiffness unchanged, since their stiffness is unlimited.
	Plain PBD constraints need a stiffness in [0, 1], so the value is clamped to that range.
	"""
	if useCompliantConstraints:
		return stiffness
	return min( 1.0, max( 0.0, stiffness ) )

def DistanceConstraint( id1, id2, stiffness, useCompliantConstraints = False ):
	"""Creates a distance constraint between two particles.

	An XPBD distance constraint is created when useCompliantConstraints is set, a PBD one otherwise.
	The stiffness goes through ClampedStiffness in both cases.
	"""
	clampedStiffness = ClampedStiffness( stiffness, useCompliantConstraints )
	if useCompliantConstraints:
		constraintType = moov.ConstraintType.XPBD_Distance
	else:
		constraintType = moov.ConstraintType.PBD_Distance
	return moov.ConstraintDescription( particleIds = [id1, id2], type = constraintType, stiffness = clampedStiffness )

def BendingConstraint( id1, id2, id3, stiffness, useCompliantConstraints = False ):
	"""Creates a bending constraint over three particles.

	Consistent with the other constraint helpers, an XPBD bending constraint is created when useCompliantConstraints
	is set, and a PBD one otherwise (the default, as before). The stiffness goes through ClampedStiffness, so PBD
	stiffness is kept within [0, 1] while compliant stiffness is passed through unchanged.
	"""
	clampedStiffness = ClampedStiffness( stiffness, useCompliantConstraints )
	constraintType = moov.ConstraintType.XPBD_Bending if useCompliantConstraints else moov.ConstraintType.PBD_Bending
	return moov.ConstraintDescription( particleIds = [id1, id2, id3], type = constraintType, stiffness = clampedStiffness )

def CapsuleConstraint( id1, id2, radius, collisionGroup, frictionCoefficient = 0.5, restitutionCoefficient = 0 ):
	"""Creates a capsule collider constraint spanning two particles.

	:param id1: id of the particle at one end of the capsule
	:param id2: id of the particle at the other end of the capsule
	:param radius: capsule radius
	:param collisionGroup: collision group assigned to the capsule
	:param frictionCoefficient: friction coefficient of the capsule collider (default 0.5)
	:param restitutionCoefficient: restitution coefficient of the capsule collider (default 0)
	:return: moov.ConstraintDescription of type ColliderCapsule
	"""
	return moov.ConstraintDescription( particleIds = [id1, id2], type = moov.ConstraintType.ColliderCapsule, radius = radius, collisionGroup = collisionGroup,
		frictionCoefficient = frictionCoefficient, restitutionCoefficient = restitutionCoefficient )

def StretchShearConstraint( particleIds, stiffness, useCompliantConstraints = False ):
	"""Creates a Cosserat stretch-shear constraint over the given particles.

	Uses the XPBD variant when useCompliantConstraints is set, the PBD variant otherwise; stiffness goes through ClampedStiffness.
	"""
	clampedStiffness = ClampedStiffness( stiffness, useCompliantConstraints )
	if useCompliantConstraints:
		constraintType = moov.ConstraintType.XPBD_StretchShear
	else:
		constraintType = moov.ConstraintType.PBD_StretchShear
	return moov.ConstraintDescription( particleIds = particleIds, type = constraintType, stiffness = clampedStiffness )

def BendTwistConstraint( particleIds, stiffness, useCompliantConstraints = False ):
	"""Creates a Cosserat bend-twist constraint over the given particles.

	Uses the XPBD variant when useCompliantConstraints is set, the PBD variant otherwise; stiffness goes through ClampedStiffness.
	"""
	clampedStiffness = ClampedStiffness( stiffness, useCompliantConstraints )
	if useCompliantConstraints:
		constraintType = moov.ConstraintType.XPBD_BendTwist
	else:
		constraintType = moov.ConstraintType.PBD_BendTwist
	return moov.ConstraintDescription( particleIds = particleIds, type = constraintType, stiffness = clampedStiffness )

def ElastonBendTwistConstraint( particleIds, stiffness, useCompliantConstraints = False ):
	"""Creates an elaston-based bend-twist constraint over the given particles.

	Uses the XPBD variant when useCompliantConstraints is set, the PBD variant otherwise; stiffness goes through ClampedStiffness.
	"""
	clampedStiffness = ClampedStiffness( stiffness, useCompliantConstraints )
	if useCompliantConstraints:
		constraintType = moov.ConstraintType.XPBD_ElastonBendTwist
	else:
		constraintType = moov.ConstraintType.PBD_ElastonBendTwist
	return moov.ConstraintDescription( particleIds = particleIds, type = constraintType, stiffness = clampedStiffness )

def ElastonStretchShearConstraint( particleIds, stiffness, useCompliantConstraints = False ):
	"""Creates an elaston-based stretch-shear constraint over the given particles.

	Uses the XPBD variant when useCompliantConstraints is set, the PBD variant otherwise; stiffness goes through ClampedStiffness.
	"""
	clampedStiffness = ClampedStiffness( stiffness, useCompliantConstraints )
	if useCompliantConstraints:
		constraintType = moov.ConstraintType.XPBD_ElastonStretchShear
	else:
		constraintType = moov.ConstraintType.PBD_ElastonStretchShear
	return moov.ConstraintDescription( particleIds = particleIds, type = constraintType, stiffness = clampedStiffness )


class ModelType:
	"""Enumeration of the particle-constraint models used by HairModel.

	The values are plain strings, stored in ModelParameters.modelType and compared directly:

	DistanceOnly - distance constraints model both stretching and bending
	DistanceBending - distance constraints for stretching, bending constraints for bending
	Cosserat - Cosserat rod stretch-shear and bend-twist constraints
	CosseratDistance - Cosserat rod constraints plus distance constraints for stiffer stretching
	CosseratElaston - DEPRECATED (all Cosserat models use elastons now), kept for backwards compatibility only;
		ModelParameters.SetModelParameters replaces it with Cosserat
	"""

	# Models built from distance (and bending) constraints
	DistanceOnly = 'DistanceOnly'
	DistanceBending = 'DistanceBending'

	# Cosserat rod models
	Cosserat = 'Cosserat'
	CosseratDistance = 'CosseratDistance'

	# Deprecated alias of Cosserat
	CosseratElaston = 'CosseratElaston'


class ModelParameters:
	"""Container class for the parameters of a HairModel instance.

	The Set*Parameters methods perform partial updates: any argument left as None keeps the current value.
	"""

	def __init__( self ):
		# stretching stiffness with optional per-vertex ramp curve and strand data channel selector
		self.stretchingStiffness = 1.0
		self.stretchingCurve = None
		self.stretchingChannel = 0
		# bending stiffness with optional per-vertex ramp curve and strand data channel selector
		self.bendingStiffness = 1.0
		self.bendingCurve = None
		self.bendingChannel = 0
		# lattice parameters (latticeCount particles per strand vertex)
		self.latticeStiffness = 1.0
		self.latticeCount = 1
		self.latticeSize = 0.1
		# mass per strand vertex with optional per-vertex ramp curve and strand data channel selector
		self.massPerVertex = 1.0
		self.massCurve = None
		self.massChannel = 0
		self.rootHolderPosition = 0.5
		self.rootVertexCount = 0
		# string describing the type of physical model to create. Possible values are taken from class ModelType
		self.modelType = ModelType.Cosserat
		# if True, the model uses compliant constraints (unlimited stiffness)
		self.useCompliantConstraints = False
		# if True, adds stretch-limiting long-range constraints
		self.limitStretch = False
		# number of levels in long-range constraints hierarchy
		self.longRangeLayerCount = 0
		self.longRangeStiffness = 1.0
		# hold-together parameters
		self.useGroupHolder = False
		self.groupHolderGenerator = 'None'
		self.groupHolderPosMin = 5
		self.groupHolderPosMax = 10
		self.groupHolderMaxGroupCount = 100
		self.groupHolderRandomSeed = 1000
		self.groupHolderStiffness = 1.0
		# capsule parameters
		self.capsuleRadius = 0.1
		self.capsuleCollisionGroup = 1000
		# hair radius
		self.particleRadius = 0
		self.particleRadiusChannel = 0
		self.particleRadiusCurve = None
		# collision groups and collisions
		self.particlesCollisionGroup = 0
		self.baseMeshCollisionGroup = 1
		self.collisionMeshesCollisionGroup = 2
		self.meshFrictionCoefficient = 0.5
		self.meshRestitutionCoefficient = 0
		# Internal parameters
		self.lastFixedVertexIndex = 1
		self.useRootHolder = False

	def SetParameters( self, latticeCount = None, latticeSize = None, latticeStiffness = None, modelType = None, useCompliantConstraints = None, limitStretch = None,
					stretchingStiffness = None, bendingStiffness = None, massPerVertex = None, useRootHolder = None, rootHolderPosition = None, rootVertexCount = None,
					massCurve = None, stretchingStiffnessCurve = None, bendingStiffnessCurve = None, stretchingChannel = None, bendingChannel = None,
					longRangeLayerCount = None, longRangeStiffness = None, massChannel = None ):
		"""Sets model and hair parameters in a single call by forwarding to SetModelParameters and SetHairParameters.

		stretchingStiffnessCurve and bendingStiffnessCurve are forwarded as stretchingCurve and bendingCurve.
		massChannel is forwarded as well, so every hair parameter is reachable from this setter.
		"""
		self.SetModelParameters( latticeCount = latticeCount, latticeSize = latticeSize, latticeStiffness = latticeStiffness, modelType = modelType,
						useCompliantConstraints = useCompliantConstraints, limitStretch = limitStretch,
						longRangeLayerCount = longRangeLayerCount, longRangeStiffness = longRangeStiffness )
		self.SetHairParameters( stretchingStiffness = stretchingStiffness, bendingStiffness = bendingStiffness, massPerVertex = massPerVertex, useRootHolder = useRootHolder,
						rootHolderPosition = rootHolderPosition, rootVertexCount = rootVertexCount, massCurve = massCurve, stretchingCurve = stretchingStiffnessCurve,
						bendingCurve = bendingStiffnessCurve, stretchingChannel = stretchingChannel, bendingChannel = bendingChannel, massChannel = massChannel )

	def SetModelParameters( self, latticeCount = None, latticeSize = None, latticeStiffness = None, modelType = None, useCompliantConstraints = None, limitStretch = None,
						longRangeLayerCount = None, longRangeStiffness = None ):
		"""Sets the parameters that define the structure of the physical model.

		The deprecated ModelType.CosseratElaston is replaced with ModelType.Cosserat and a warning is printed.
		"""
		if latticeCount is not None:
			self.latticeCount = latticeCount
		if latticeSize is not None:
			self.latticeSize = latticeSize
		if latticeStiffness is not None:
			self.latticeStiffness = latticeStiffness
		if modelType is not None:
			if modelType == ModelType.CosseratElaston:
				self.modelType = ModelType.Cosserat
				print( "WARNING: the CosseratElaston model is deprecated; using Cosserat instead" )
			else:
				self.modelType = modelType
		if useCompliantConstraints is not None:
			self.useCompliantConstraints = useCompliantConstraints
		if limitStretch is not None:
			self.limitStretch = limitStretch
		if longRangeLayerCount is not None:
			self.longRangeLayerCount = longRangeLayerCount
		if longRangeStiffness is not None:
			self.longRangeStiffness = longRangeStiffness

	def SetHairParameters( self, stretchingStiffness = None, bendingStiffness = None, massPerVertex = None, useRootHolder = None, rootHolderPosition = None, rootVertexCount = None,
						massCurve = None, stretchingCurve = None, bendingCurve = None, stretchingChannel = None, bendingChannel = None, massChannel = None ):
		"""Sets the per-strand physical parameters.

		rootVertexCount is clamped to be non-negative. lastFixedVertexIndex is recomputed from rootVertexCount and
		useRootHolder on every call, even when neither of them is passed.
		"""
		if stretchingStiffness is not None:
			self.stretchingStiffness = stretchingStiffness
		if bendingStiffness is not None:
			self.bendingStiffness = bendingStiffness
		if massPerVertex is not None:
			self.massPerVertex = massPerVertex
		if useRootHolder is not None:
			self.useRootHolder = useRootHolder
		if rootHolderPosition is not None:
			self.rootHolderPosition = rootHolderPosition
		if rootVertexCount is not None:
			self.rootVertexCount = max( 0, rootVertexCount )
		self.lastFixedVertexIndex = self.rootVertexCount if self.useRootHolder else self.rootVertexCount + 1
		if massCurve is not None:
			self.massCurve = massCurve
		if stretchingCurve is not None:
			self.stretchingCurve = stretchingCurve
		if bendingCurve is not None:
			self.bendingCurve = bendingCurve
		if stretchingChannel is not None:
			self.stretchingChannel = stretchingChannel
		if bendingChannel is not None:
			self.bendingChannel = bendingChannel
		if massChannel is not None:
			self.massChannel = massChannel

	def SetGroupHolderParameters( self, useGroupHolder = None, groupHolderGenerator = None, groupHolderPosMin = None, groupHolderPosMax = None, groupHolderMaxGroupCount = None, groupHolderRandomSeed = None, groupHolderStiffness = None ):
		"""Sets the hold-together (group holder) parameters."""
		if useGroupHolder is not None:
			self.useGroupHolder = useGroupHolder
		if groupHolderGenerator is not None:
			self.groupHolderGenerator = groupHolderGenerator
		if groupHolderPosMin is not None:
			self.groupHolderPosMin = groupHolderPosMin
		if groupHolderPosMax is not None:
			self.groupHolderPosMax = groupHolderPosMax
		if groupHolderMaxGroupCount is not None:
			self.groupHolderMaxGroupCount = groupHolderMaxGroupCount
		if groupHolderRandomSeed is not None:
			self.groupHolderRandomSeed = groupHolderRandomSeed
		if groupHolderStiffness is not None:
			self.groupHolderStiffness = groupHolderStiffness

	def SetCapsuleParameters( self, capsuleRadius = None, capsuleCollisionGroup = None ):
		"""Sets the strand capsule collider parameters."""
		if capsuleRadius is not None:
			self.capsuleRadius = capsuleRadius
		# TODO: ? capsule collision group by strand
		if capsuleCollisionGroup is not None:
			self.capsuleCollisionGroup = capsuleCollisionGroup

	def SetCollisionParameters( self, particleRadius = None, particleRadiusChannel = None, particleRadiusCurve = None, meshFrictionCoefficient = None, meshRestitutionCoefficient = None ):
		"""Sets the particle radius (with optional channel/curve multipliers) and the mesh contact coefficients."""
		if particleRadius is not None:
			self.particleRadius = particleRadius
		if particleRadiusChannel is not None:
			self.particleRadiusChannel = particleRadiusChannel
		if particleRadiusCurve is not None:
			self.particleRadiusCurve = particleRadiusCurve
		if meshFrictionCoefficient is not None:
			self.meshFrictionCoefficient = meshFrictionCoefficient
		if meshRestitutionCoefficient is not None:
			self.meshRestitutionCoefficient = meshRestitutionCoefficient

	def SetCollisionGroupParameters( self, particlesCollisionGroup = None, baseMeshCollisionGroup = None, collisionMeshesCollisionGroup = None ):
		"""Sets the collision groups of strand particles, the base mesh and the collision meshes."""
		if particlesCollisionGroup is not None:
			self.particlesCollisionGroup = particlesCollisionGroup
		if baseMeshCollisionGroup is not None:
			self.baseMeshCollisionGroup = baseMeshCollisionGroup
		if collisionMeshesCollisionGroup is not None:
			self.collisionMeshesCollisionGroup = collisionMeshesCollisionGroup

	def SetMeshParameters( self, meshFrictionCoefficient = None, meshRestitutionCoefficient = None ):
		"""Sets only the mesh contact coefficients (a subset of SetCollisionParameters)."""
		self.SetCollisionParameters( meshFrictionCoefficient = meshFrictionCoefficient, meshRestitutionCoefficient = meshRestitutionCoefficient )


class StrandChannel:
	"""Class for working with strand data channels.

	A ChannelSelector value picks the channel: None or values <= 0 select nothing, values below 1001 select the
	per-strand channel (value - 1), and values from 1001 up select the per-vertex channel (value - 1001).
	"""
	def __init__( self, hair, channelSelectorValue ):
		self.hair = hair
		index, dataType = StrandChannel.GetStrandChannelIndexAndType( channelSelectorValue )
		self.channelIndex = index
		self.channelType = dataType
		self.isValidChannel = StrandChannel.IsValidChannel( self.hair, index, dataType )

	def GetValue( self, strandIndex, strandPointIndex ):
		"""Returns the channel value for the given strand (and vertex, for per-vertex channels)."""
		return StrandChannel.GetStrandChannelValue( self.hair, self.channelIndex, self.channelType, strandIndex, strandPointIndex )

	@staticmethod
	def GetStrandChannelIndexAndType( channelSelectorValue ):
		"""Returns the strand channel index and type from the ChannelSelector parameter value, or (-1, None) if no channel was selected"""
		if channelSelectorValue is None or channelSelectorValue <= 0:
			return ( -1, None )
		if channelSelectorValue < 1001:
			return ( channelSelectorValue - 1, ox.StrandDataType.PerStrand )
		return ( channelSelectorValue - 1001, ox.StrandDataType.PerVertex )

	@staticmethod
	def IsValidChannel( hair, channelIndex, channelType ):
		"""Returns True if a channel type was selected and the index is below the hair's channel count for that type."""
		return channelType is not None and channelIndex < hair.GetStrandChannelCount( channelType )

	@staticmethod
	def GetStrandChannelValue( hair, channelIndex, channelType, strandIndex, strandPointIndex ):
		"""Reads one channel value; strandPointIndex is ignored for per-strand channels. Returns None for an unknown type."""
		if channelType == ox.StrandDataType.PerStrand:
			return hair.GetStrandChannelData( channelIndex, strandIndex )
		if channelType == ox.StrandDataType.PerVertex:
			return hair.GetStrandChannelData( channelIndex, strandIndex, strandPointIndex )
		return None


class PropagatedStrandData:
	"""Describes where a propagated strand attaches to its source strand.

	A strand counts as propagated when the hair keeps surface dependencies and the third barycentric coordinate of its
	surface dependency is below -0.5. For such strands faceIndex is read as the source strand index, and the first
	barycentric coordinate (clamped to [0, 1]) as the normalized position along the source strand.

	The attributes sourceStrand, sourcePointIndex1, sourcePointIndex2 and fraction exist only for propagated strands.
	fraction is 1 when the position coincides with sourcePointIndex1 and decreases towards 0 at sourcePointIndex2.
	"""
	def __init__( self, hair, strandIndex ):
		self.isPropagated = False
		if not hair.KeepsSurfaceDependency():
			return
		dependency = hair.GetSurfaceDependency( strandIndex )
		coordinates = dependency.barycentricCoordinate
		self.isPropagated = coordinates[2] < -0.5
		if not self.isPropagated:
			return

		self.sourceStrand = dependency.faceIndex
		lastSourcePoint = hair.GetStrandPointCount( self.sourceStrand ) - 1
		normalizedPosition = min( max( coordinates[0], 0.0 ), 1.0 )
		positionOnStrand = normalizedPosition * lastSourcePoint
		self.sourcePointIndex1 = int( positionOnStrand )
		self.sourcePointIndex2 = min( self.sourcePointIndex1 + 1, lastSourcePoint )
		self.fraction = max( 0.0, min( 1.0, self.sourcePointIndex1 + 1 - positionOnStrand ) )


class StrandModel:
	"""Contains particle/constraint data and solver sets for a strand of hair."""

	def __init__( self, hair, strandIndex, modelParameters ):
		"""Initializes a strand model without any solver objects; CreateSolverObjects populates them."""
		# Solver handles, None until the corresponding objects are created
		self.particleSet = None
		self.constraintSet = None
		self.propagatedRootConstraintSet = None
		self.capsuleSet = None
		# Ids of the solver particles, latticeCount per strand vertex
		self.particleIds = []
		self.params = modelParameters
		self.propagatedData = PropagatedStrandData( hair, strandIndex )

	def IsPropagated( self ):
		"""Returns whether this strand was detected as propagated from another strand (see PropagatedStrandData)."""
		propagatedData = self.propagatedData
		return propagatedData.isPropagated

	@staticmethod
	def GetRampAndChannelMultipliers( hair, strandIndex, curve, channelSelectorValue, minValue = None, maxValue = None ):
		"""Returns per-vertex multipliers combining a ramp curve and a strand data channel.

		For each strand vertex the multiplier starts at 1.0. It is multiplied by the channel value when the channel
		selector picks a valid channel, and by the curve evaluated at the vertex's normalized position along the strand
		when a curve is given. The result is then clamped to minValue / maxValue where those bounds are given.

		:return: list with one multiplier per strand vertex, or None when there is neither a curve nor a valid channel
		"""
		strandChannel = StrandChannel( hair, channelSelectorValue )
		if curve is None and not strandChannel.isValidChannel:
			return None
		strandPointCount = hair.GetStrandPointCount( strandIndex )
		# Denominator of the normalized position; kept >= 1 so single-point strands evaluate the curve at 0 instead of dividing by zero
		lastPointIndex = max( strandPointCount - 1, 1 )
		result = []
		for index in range( strandPointCount ):
			multiplier = 1.0
			if strandChannel.isValidChannel:
				multiplier *= strandChannel.GetValue( strandIndex, index )
			if curve is not None:
				multiplier *= curve.Evaluate( float( index ) / lastPointIndex )
			if minValue is not None:
				multiplier = max( multiplier, minValue )
			if maxValue is not None:
				multiplier = min( multiplier, maxValue )
			result.append( multiplier )
		return result


	def CreateSolverObjects( self, solver, hair, strandIndex ):
		'''Creates solver particles and constraints. Raises exception if unsuccessful.

		Existing solver objects are cleared first. Creation is attempted up to five times, and a RuntimeError is
		raised if every attempt fails.
		'''
		self.ClearSolverObjects( solver )
		maxRetries = 5
		for attempt in range( 1, maxRetries + 1 ):
			if self.TryCreateSolverObjects( solver, hair, strandIndex ):
				return
			print( 'Solver objects not created, retry {0}/{1}'.format( attempt, maxRetries ) )
		raise RuntimeError( 'Could not create Moov solver objects' )

	def TryCreateSolverObjects( self, solver, hair, strandIndex ):
		"""Creates solver particles and constraints for this strand.

		Returns True on success. If the solver creates fewer particles or constraints than requested, the solver objects
		created so far are removed with ClearSolverObjects and False is returned.

		:raises ValueError: if self.params.modelType is not one of the ModelType values handled below
		"""
		latticeCount = self.params.latticeCount
		useCompliantConstraints = self.params.useCompliantConstraints

		self.particleIds = []
		useElastons = self.params.modelType == ModelType.Cosserat or self.params.modelType == ModelType.CosseratDistance
		particleData = self.GetParticleDescriptions( hair, strandIndex, useElastons, self.params.latticeSize )
		# self.particleIds is passed as an output list, presumably filled by the solver with the new particle ids
		self.particleSet = solver.CreateParticles( particleData, moov.ParticleInformation.All, self.particleIds )
		if len( self.particleSet ) != len( particleData ):
			self.ClearSolverObjects( solver )
			return False

		stretchingStiffness = self.params.stretchingStiffness
		bendingStiffness = self.params.bendingStiffness

		# GetParticleDescriptions appends root holder particles only for non-propagated strands, and
		# GetRootHolderParticleDescriptions returns none when no moving vertex follows the last fixed one. Treat the
		# trailing lattice as a root holder only if it really exists; otherwise slicing it off would drop the particles
		# of the last strand vertex from the standard constraints.
		strandParticleCount = hair.GetStrandPointCount( strandIndex ) * latticeCount
		createRootHolder = self.params.useRootHolder and not self.propagatedData.isPropagated and len( particleData ) > strandParticleCount

		# exclude root holder from standard constraints
		particleIdsWithoutRootHolder = self.particleIds[:-latticeCount] if createRootHolder else self.particleIds
		# index of the first particle of the first dynamic (non-fixed) strand vertex
		firstDynamicParticleIndex = self.params.lastFixedVertexIndex * latticeCount

		stretchingMultipliers = self.GetRampAndChannelMultipliers( hair.GetHair(), strandIndex, self.params.stretchingCurve, self.params.stretchingChannel, 0, 1 )
		bendingMultipliers = self.GetRampAndChannelMultipliers( hair.GetHair(), strandIndex, self.params.bendingCurve, self.params.bendingChannel, 0, 1 )

		constraintData = []
		modelType = self.params.modelType
		if modelType == ModelType.DistanceOnly:
			constraintData.extend( self.GetDistanceOnlyConstraintDescriptions( particleIdsWithoutRootHolder, latticeCount, 0, stretchingStiffness, bendingStiffness, useCompliantConstraints, stretchingMultipliers, bendingMultipliers ) )
		elif modelType == ModelType.DistanceBending:
			constraintData.extend( self.GetDistanceBendingConstraintDescriptions( particleIdsWithoutRootHolder, latticeCount, 0, stretchingStiffness, bendingStiffness, useCompliantConstraints, stretchingMultipliers, bendingMultipliers ) )
		elif modelType == ModelType.Cosserat:
			constraintData.extend( self.GetCosseratConstraintDescriptions( particleIdsWithoutRootHolder, latticeCount, 0, stretchingStiffness, bendingStiffness, useCompliantConstraints, True, stretchingMultipliers, bendingMultipliers ) )
		elif modelType == ModelType.CosseratDistance:
			# longitudinal distance constraints on top of the Cosserat constraints for extra stretching stiffness
			constraintData.extend( self.GetLatticeLongitudinalConstraintDescriptions( particleIdsWithoutRootHolder, latticeCount, 0, stretchingStiffness, stretchingMultipliers, useCompliantConstraints ) )
			constraintData.extend( self.GetCosseratConstraintDescriptions( particleIdsWithoutRootHolder, latticeCount, 0, stretchingStiffness, bendingStiffness, useCompliantConstraints, True, stretchingMultipliers, bendingMultipliers ) )
		else:
			raise ValueError( "Unknown model type." )

		# Add transverse lattice constraints starting from the first dynamic strand vertex (root holder particles included).
		# It is still possible to have a lattice with Cosserat rods, although useless.
		constraintData.extend( self.GetLatticeTransverseConstraintDescriptions( self.particleIds, latticeCount,
			firstDynamicParticleIndex, self.params.latticeStiffness, useCompliantConstraints ) )

		if self.params.limitStretch:
			constraintData.extend( self.GetStretchLimitingConstraintDescriptions( particleIdsWithoutRootHolder, latticeCount, firstDynamicParticleIndex, stretchingStiffness ) )

		if self.params.longRangeLayerCount > 0:
			constraintData.extend( self.GetLongRangeConstraintDescriptions( particleIdsWithoutRootHolder, latticeCount, firstDynamicParticleIndex, self.params.longRangeLayerCount, stretchingStiffness, self.params.longRangeStiffness, useCompliantConstraints ) )

		if createRootHolder:
			constraintData.extend( self.GetRootHolderConstraintDescriptions( self.particleIds, latticeCount, self.params.lastFixedVertexIndex, bendingStiffness, useCompliantConstraints ) )

		self.constraintSet = solver.CreateConstraints( constraintData, moov.ConstraintInformation.All )
		if len( self.constraintSet ) != len( constraintData ):
			self.ClearSolverObjects( solver )
			return False

		return True

	def ResetCapsules( self, solver, create = False ):
		"""Removes this strand's capsule colliders and, when create is True, builds a new set.

		One capsule is created per strand segment, between the first lattice particles of consecutive vertices.
		When the strand uses a root holder (useRootHolder set and not propagated), the trailing lattice is skipped as well.
		"""
		if self.capsuleSet is not None:
			solver.RemoveConstraints( self.capsuleSet )
			self.capsuleSet = None
		if not create:
			return

		step = self.params.latticeCount
		hasRootHolder = self.params.useRootHolder and not self.IsPropagated()
		excludedCount = step * ( 2 if hasRootHolder else 1 )
		segmentEnd = len( self.particleIds ) - excludedCount
		capsuleData = []
		for startIndex in range( 0, segmentEnd, step ):
			capsuleData.append( CapsuleConstraint( self.particleIds[startIndex], self.particleIds[startIndex + step], self.params.capsuleRadius, self.params.capsuleCollisionGroup ) )
		self.capsuleSet = solver.CreateConstraints( capsuleData, moov.ConstraintInformation.All )

	def ClearSolverObjects( self, solver ):
		"""Removes all solver objects of this strand: constraints (including capsules) first, then particles."""
		for clear in ( self.ClearConstraints, self.ClearParticles ):
			clear( solver )

	def ClearParticles( self, solver ):
		"""Forgets the particle ids and removes this strand's particles from the solver, if any were created."""
		self.particleIds = []
		if self.particleSet is None:
			return
		solver.RemoveParticles( self.particleSet )
		self.particleSet = None

	def ClearConstraints( self, solver ):
		"""Removes capsules, the propagated-root constraints and the main constraint set from the solver, in that order."""
		self.ResetCapsules( solver )
		for attributeName in ( 'propagatedRootConstraintSet', 'constraintSet' ):
			constraintSet = getattr( self, attributeName )
			if constraintSet is not None:
				solver.RemoveConstraints( constraintSet )
				setattr( self, attributeName, None )

	def GetVertexParticleIds( self, strandPointIndex ):
		"""Returns a list of particle ids corresponding to a given strand vertex.

		No range check is performed; an index past the end of the strand yields a shorter or empty list.
		"""
		latticeCount = self.params.latticeCount
		begin = latticeCount * strandPointIndex
		return self.particleIds[begin:begin + latticeCount]

	def GetParticleId( self, strandPointIndex, latticeIndex ):
		"""Returns the particle id corresponding to a given lattice position in a given strand vertex.

		:param strandPointIndex: vertex index along the strand (non-negative)
		:param latticeIndex: position within the vertex lattice, in [0, latticeCount)
		:raises ValueError: if either index is out of range. Without the checks, a negative vertex index would wrap
			around through Python negative indexing and a too-large lattice index would return a particle of the next vertex.
		"""
		latticeCount = self.params.latticeCount
		if strandPointIndex < 0 or ( strandPointIndex + 1 ) * latticeCount > len( self.particleIds ):
			raise ValueError( "Strand point index out of range." )
		if latticeIndex < 0 or latticeIndex >= latticeCount:
			raise ValueError( "Lattice index out of range." )
		return self.particleIds[latticeCount * strandPointIndex + latticeIndex]


	def GetLatticeNodePositions( self, hair, strandIndex, strandPointIndex, strandPoints = None ):
		"""Returns a list containing the lattice positions around a strand vertex.

		With latticeCount < 2 the list holds only the vertex position. Otherwise it holds latticeCount moov.Vector3
		positions: the strand vertex itself first, then the remaining corners of a regular polygon rotated about the
		local strand direction. That direction runs from the previous vertex to this one, or from the root to the next
		vertex for the root. The strand vertex is a corner of the polygon, not its center. The circumradius is
		latticeSize when the start offset is perpendicular to the strand direction.

		:param hair: hair object; queried for the strand count and lattice orientation, and for point positions when strandPoints is None
		:param strandIndex: index of the strand
		:param strandPointIndex: index of the vertex within the strand
		:param strandPoints: optional precomputed world-space points of the strand, used instead of querying hair for positions
		:raises ValueError: if strandIndex or strandPointIndex is out of range
		"""
		if ( strandPoints is not None and strandPointIndex >= len( strandPoints ) ) or strandIndex >= hair.GetStrandCount() or ( strandPoints is None and strandPointIndex >= hair.GetStrandPointCount( strandIndex ) ):
			raise ValueError( 'Strand index/strand point index out of range.' )

		result = []
		oxPos = hair.GetStrandPointInWorldCoordinates( strandIndex, strandPointIndex ) if strandPoints is None else strandPoints[strandPointIndex]

		if self.params.latticeCount < 2:
			result.append( moov.Vector3( oxPos[0], oxPos[1], oxPos[2] ) )
			return result

		strandPoint = Math.Vector3f( oxPos )
		# Local strand direction: incoming segment for inner vertices, outgoing segment for the root
		if strandPointIndex > 0:
			oxPrevPos = hair.GetStrandPointInWorldCoordinates( strandIndex, strandPointIndex - 1 ) if strandPoints is None else strandPoints[strandPointIndex - 1]
			prevStrandPoint = Math.Vector3f( oxPrevPos )
			axis = strandPoint - prevStrandPoint
		else:
			# NOTE(review): a single-point strand has no vertex at strandPointIndex + 1; this lookup then goes out of range
			oxNextPos = hair.GetStrandPointInWorldCoordinates( strandIndex, strandPointIndex + 1 ) if strandPoints is None else strandPoints[strandPointIndex + 1]
			nextStrandPoint = Math.Vector3f( oxNextPos )
			axis = nextStrandPoint - strandPoint
		axis = axis.normalize()

		# positionOffset is center-to-vertex vector for the lattice polygon
		if hair.IsUsingPerStrandRotationAngles():
			# NOTE(review): this offset is not projected onto the plane perpendicular to axis; presumably it is already perpendicular - verify
			positionOffset = Math.Vector3f( hair.GetStrandLatticeOffset( strandIndex ) ).normalize() * self.params.latticeSize
		else:
			# Need a start offset; choose horizontal normal to axis (-y is down)
			horizontalNormal = axis.cross( Math.Vector3f( 0, -1, 0 ) )
			# A (near-)vertical axis makes the cross product degenerate; fall back to +z, which is perpendicular to y
			if horizontalNormal.lengthSquared() < 1e-4:
				horizontalNormal = Math.Vector3f( 0, 0, 1 )
			positionOffset = horizontalNormal.normalize() * self.params.latticeSize

		# Rotation by one polygon step (2*pi / latticeCount) about the strand direction
		rotationQuaternion = Math.Quaternion.FromAngleAxis( 2.0 * math.pi / self.params.latticeCount, axis )
		rotationMatrix = rotationQuaternion.ToMatrix()

		# Hair vertex position corresponds to the first lattice particle
		result.append( moov.Vector3( strandPoint[0], strandPoint[1], strandPoint[2] ) )
		# Recompute positionOffset as polygon side instead of center-to-vertex vector
		positionOffset = rotationMatrix * positionOffset - positionOffset
		# Walk around the polygon one side at a time, rotating the side vector after each step
		for index in range( 1, self.params.latticeCount ):
			strandPoint += positionOffset
			result.append( moov.Vector3( strandPoint[0], strandPoint[1], strandPoint[2] ) )
			positionOffset = rotationMatrix * positionOffset

		return result


	def GetParticleDescriptions( self, hair, strandIndex, useElastons = False, maxRandomOffset = 0 ):
		"""Returns a list of particle descriptions for the strand lattice.

		One lattice of particles (see GetLatticeNodePositions) is created per strand vertex. The particles of vertices up
		to and including params.lastFixedVertexIndex get zero mass, unless the strand is propagated. Root holder
		particles may be appended at the end when useRootHolder is set and the strand is not propagated.

		:param hair: hair object providing strand points and, through GetHair(), the underlying hair (strand ids, channels)
		:param strandIndex: index of the strand
		:param useElastons: if True, particles after the root vertex are elastons with initial orientations and inertia moments
		:param maxRandomOffset: bound of the random horizontal displacement, applied fully at the tip and scaled down towards the root
		:return: list of moov.ParticleDescription
		"""
		strandPointCount = hair.GetStrandPointCount( strandIndex )

		result = []

		# Random displacement of the strand tip to avoid strands in unstable equilibrium.
		# A private generator seeded with the strand id yields the same offsets as seeding the module-level generator,
		# without resetting the global `random` state that other code in the process may rely on.
		randomGenerator = random.Random( hair.GetHair().GetStrandId( strandIndex ) )
		maxOffsetX = randomGenerator.uniform( -maxRandomOffset, maxRandomOffset )
		maxOffsetY = randomGenerator.uniform( -maxRandomOffset, maxRandomOffset )

		strandPoints = hair.GetStrandPointsInWorldCoordinates( strandIndex )
		# Should not have negative radius
		radiusMultipliers = self.GetRampAndChannelMultipliers( hair.GetHair(), strandIndex, self.params.particleRadiusCurve, self.params.particleRadiusChannel, 0 )
		# Can not have 0-mass particles apart from roots, hence the lower limit
		massMultipliers = self.GetRampAndChannelMultipliers( hair.GetHair(), strandIndex, self.params.massCurve, self.params.massChannel, 0.01 )

		strandParticlePositions = []
		segmentInertiaMoments = []
		lastMass = 0
		lastPositions = []

		for strandVertexIndex in range( strandPointCount ):
			# First particles have infinite mass to preserve root position/orientation
			mass = self.params.massPerVertex / self.params.latticeCount if strandVertexIndex > self.params.lastFixedVertexIndex or self.propagatedData.isPropagated else 0
			if mass != 0 and massMultipliers is not None:
				mass *= massMultipliers[strandVertexIndex]
			radius = self.params.particleRadius
			if radiusMultipliers is not None:
				radius *= radiusMultipliers[strandVertexIndex]

			positions = self.GetLatticeNodePositions( hair, strandIndex, strandVertexIndex, strandPoints )

			# Gradually displace particles from root to tip (the "Y" offset is applied to z, since y is the vertical axis)
			if strandVertexIndex > 1:
				scaleFactor = ( strandVertexIndex - 1.0 ) / ( strandPointCount - 2.0 )
				for pos in positions:
					pos[0] += maxOffsetX * scaleFactor
					pos[2] += maxOffsetY * scaleFactor

			strandParticlePositions.extend( positions )
			# Root particle does not need to be an elaston; segment orientations are kept by subsequent particles
			needsOrientation = useElastons and strandVertexIndex > 0
			particleType = moov.ParticleType.PBD_Elaston if needsOrientation else moov.ParticleType.PBD_Particle
			result.extend( [ moov.ParticleDescription( type = particleType, x = pos, mass = mass, radius = radius, collisionGroup = self.params.particlesCollisionGroup ) for pos in positions ] )

			if useElastons:
				# Compute moment of inertia to get the same behaviour as the non-elaston Cosserat model
				sumInvMasses = ( 0 if mass == 0 else 1/mass ) + ( 0 if lastMass == 0 else 1/lastMass )
				segmentInertiaMoments.extend( [0] * self.params.latticeCount if strandVertexIndex == 0 or sumInvMasses == 0 \
					else [( pos - lastPos ).LengthSquared() / sumInvMasses for pos, lastPos in zip( positions, lastPositions )] )
				lastMass = mass
				lastPositions = positions

		if useElastons:
			orientations = self.ComputeInitialOrientationsFromPositions_Moov( strandParticlePositions, latticeCount = self.params.latticeCount, addUnityAsFirstElement = True )
			for particle, orientation, inertia in zip( result, orientations, segmentInertiaMoments ):
				particle.rotation = orientation
				particle.inertiaMoment = moov.Vector3( inertia, inertia, inertia )

		if self.params.useRootHolder and not self.propagatedData.isPropagated:
			# Root holder does not need elastons since it uses distance constraints
			result.extend( self.GetRootHolderParticleDescriptions( result ) )

		return result

	def GetRootHolderParticleDescriptions( self, particleDescriptions, particleType = moov.ParticleType.PBD_Particle ):
		"""Generates root holder particle descriptions. Positions are calculated as linear interpolation between last fixed and first moving vertices.

		rootHolderPosition is the interpolation parameter (0 = last fixed vertex, 1 = first moving vertex). The holder
		particles have zero mass. An empty list is returned when no vertex follows the last fixed one.
		"""
		latticeCount = self.params.latticeCount
		blend = self.params.rootHolderPosition
		firstIndex = latticeCount * self.params.lastFixedVertexIndex
		if len( particleDescriptions ) < firstIndex + 2 * latticeCount:
			return []

		holderDescriptions = []
		for index in range( firstIndex, firstIndex + latticeCount ):
			fixedPosition = particleDescriptions[index].x
			movingPosition = particleDescriptions[index + latticeCount].x
			position = ( 1.0 - blend ) * fixedPosition + blend * movingPosition
			holderDescriptions.append( moov.ParticleDescription( type = particleType, x = position, mass = 0, collisionGroup = self.params.particlesCollisionGroup ) )
		return holderDescriptions


	@staticmethod
	def GetCosseratConstraintDescriptions( particleIds, latticeCount = 1, startIndex = 0, stiffnessSS = 1.0, stiffnessBT = 1.0, useCompliantConstraints = False, useElastons = False, stretchingMultipliers = None, bendingMultipliers = None ):
		"""Returns Cosserat rod constraint descriptions along a strand.

		Only the particles at indices startIndex, startIndex + latticeCount, ... are used, i.e. one per vertex lattice.
		One stretch-shear constraint is created per segment and one bend-twist constraint per pair of adjacent segments.
		Strands with fewer than three lattices get no constraints.

		With elastons, a segment's orientation is carried by the elaston at its far end. Stretch-shear therefore uses
		[start, end, end] and bend-twist uses the two segment-end elastons. Without elastons, stretch-shear uses
		[previous, start, end] and bend-twist uses [previous, start, next, nextNext]; for the first segment, previous is
		the start particle.

		stretchingMultipliers / bendingMultipliers, when given, are indexed per vertex and scale the clamped stiffness.
		"""
		particleCount = len( particleIds )
		if particleCount < 3 * latticeCount:
			return []

		ssStiffness = ClampedStiffness( stiffnessSS, useCompliantConstraints )
		btStiffness = ClampedStiffness( stiffnessBT, useCompliantConstraints )
		if useElastons:
			ssConstraintType = moov.ConstraintType.XPBD_ElastonStretchShear if useCompliantConstraints else moov.ConstraintType.PBD_ElastonStretchShear
			btConstraintType = moov.ConstraintType.XPBD_ElastonBendTwist if useCompliantConstraints else moov.ConstraintType.PBD_ElastonBendTwist
		else:
			ssConstraintType = moov.ConstraintType.XPBD_StretchShear if useCompliantConstraints else moov.ConstraintType.PBD_StretchShear
			btConstraintType = moov.ConstraintType.XPBD_BendTwist if useCompliantConstraints else moov.ConstraintType.PBD_BendTwist

		def RampedStiffness( multipliers, index, stiffness ):
			return multipliers[index // latticeCount] * stiffness if multipliers is not None else stiffness

		result = []

		# Stretch-shear constraints, one per segment
		for index in range( startIndex, particleCount - latticeCount, latticeCount ):
			nextIndex = index + latticeCount
			if useElastons:
				indices = [index, nextIndex, nextIndex]
			else:
				previousIndex = index - latticeCount if index > 0 else index
				indices = [previousIndex, index, nextIndex]
			result.append( moov.ConstraintDescription( particleIds = [particleIds[i] for i in indices], type = ssConstraintType,
				stiffness = RampedStiffness( stretchingMultipliers, index, ssStiffness ) ) )

		# Bend-twist constraints, one per pair of adjacent segments
		for index in range( startIndex, particleCount - 2 * latticeCount, latticeCount ):
			nextIndex = index + latticeCount
			nextNextIndex = index + 2 * latticeCount
			if useElastons:
				indices = [nextIndex, nextNextIndex]
			else:
				previousIndex = index - latticeCount if index > 0 else index
				indices = [previousIndex, index, nextIndex, nextNextIndex]
			result.append( moov.ConstraintDescription( particleIds = [particleIds[i] for i in indices], type = btConstraintType,
				stiffness = RampedStiffness( bendingMultipliers, index, btStiffness ) ) )

		return result


	@staticmethod
	def GetDistanceBendingConstraintDescriptions( particleIds, latticeCount = 1, startIndex = 0, stiffnessDistance = 1.0, stiffnessBending = 1.0, useCompliantConstraints = False, stretchingMultipliers = None, bendingMultipliers = None ):
		"""Creates longitudinal distance constraints plus three-particle bending constraints along the strand.

		Bending constraints join the particles with the same lattice offset in three consecutive lattice nodes and are
		only created when stiffnessBending is non-zero. bendingMultipliers, when given, is indexed per node and scales
		the bending stiffness. Returns an empty list when there are fewer than 3 * latticeCount particles.
		"""
		particleCount = len( particleIds )
		if particleCount < 3 * latticeCount:
			return []

		descriptions = StrandModel.GetLatticeLongitudinalConstraintDescriptions( particleIds, latticeCount, startIndex, stiffnessDistance, stretchingMultipliers, useCompliantConstraints )
		if stiffnessBending == 0:
			return descriptions

		baseStiffness = ClampedStiffness( stiffnessBending, useCompliantConstraints )
		if useCompliantConstraints:
			bendType = moov.ConstraintType.XPBD_Bending
		else:
			bendType = moov.ConstraintType.PBD_Bending

		for nodeStart in range( startIndex, particleCount - 2 * latticeCount, latticeCount ):
			if bendingMultipliers is None:
				nodeStiffness = baseStiffness
			else:
				nodeStiffness = bendingMultipliers[nodeStart // latticeCount] * baseStiffness
			for offset in range( latticeCount ):
				triple = [particleIds[nodeStart + offset], particleIds[nodeStart + latticeCount + offset], particleIds[nodeStart + 2 * latticeCount + offset]]
				descriptions.append( moov.ConstraintDescription( particleIds = triple, type = bendType, stiffness = nodeStiffness ) )

		return descriptions

	@staticmethod
	def GetDistanceOnlyConstraintDescriptions( particleIds, latticeCount = 1, startIndex = 0, stiffnessDistance = 1.0, stiffnessBending = 1.0, useCompliantConstraints = False, stretchingMultipliers = None, bendingMultipliers = None ):
		"""Creates distance constraints along the strand to account for both stretching and bending.

		Stretching uses the longitudinal lattice constraints. Bending is approximated by distance constraints that skip
		one lattice node (node i to node i + 2, same lattice offset). If stiffnessBending is zero, only stretching
		constraints are created. Returns an empty list when there are fewer than 3 * latticeCount particles.
		"""
		particleCount = len( particleIds )
		if particleCount < 3 * latticeCount:
			return []

		descriptions = StrandModel.GetLatticeLongitudinalConstraintDescriptions( particleIds, latticeCount, startIndex, stiffnessDistance, stretchingMultipliers, useCompliantConstraints )
		if stiffnessBending == 0:
			return descriptions

		skipStiffness = ClampedStiffness( stiffnessBending, useCompliantConstraints )
		if useCompliantConstraints:
			skipType = moov.ConstraintType.XPBD_Distance
		else:
			skipType = moov.ConstraintType.PBD_Distance

		skip = 2 * latticeCount
		for nodeStart in range( startIndex, particleCount - skip, latticeCount ):
			if bendingMultipliers is None:
				nodeStiffness = skipStiffness
			else:
				nodeStiffness = bendingMultipliers[nodeStart // latticeCount] * skipStiffness
			for offset in range( latticeCount ):
				pair = [particleIds[nodeStart + offset], particleIds[nodeStart + skip + offset]]
				descriptions.append( moov.ConstraintDescription( particleIds = pair, type = skipType, stiffness = nodeStiffness ) )

		return descriptions

	@staticmethod
	def GetLatticeTransverseConstraintDescriptions( particleIds, latticeCount = 1, startIndex = 0, stiffnessLattice = 1.0, useCompliantConstraints = False ):
		"""Creates distance constraints between the particles belonging to each lattice node (hair vertex).

		Within every node the particles form a closed ring: the particle at offset k is joined to the particle at
		offset (k + 1) mod latticeCount. Returns an empty list for single-particle lattices (latticeCount < 2) or
		when there are fewer particles than one lattice node.
		"""
		result = []
		particleCount = len( particleIds )
		if particleCount < latticeCount or latticeCount < 2:
			return result

		stiffness = ClampedStiffness( stiffnessLattice, useCompliantConstraints )
		constraintType = moov.ConstraintType.PBD_Distance if not useCompliantConstraints else moov.ConstraintType.XPBD_Distance
		# list() is required: Python 3 range objects support slicing but not concatenation with '+'
		nodeStartOffset = list( range( latticeCount ) )
		# Ring neighbour of every offset: [1, 2, ..., latticeCount - 1, 0]
		nodeEndOffset = nodeStartOffset[1:] + nodeStartOffset[:1]
		for index in range( startIndex, particleCount, latticeCount ):
			for nodeStartIndex, nodeEndIndex in zip( nodeStartOffset, nodeEndOffset ):
				particles = [particleIds[index + nodeStartIndex], particleIds[index + nodeEndIndex]]
				result.append( moov.ConstraintDescription( particleIds = particles, type = constraintType, stiffness = stiffness ) )

		return result

	@staticmethod
	def GetLatticeLongitudinalConstraintDescriptions( particleIds, latticeCount = 1, startIndex = 0, stiffnessDistance = 1.0, stiffnessMultipliers = None, useCompliantConstraints = False ):
		"""Creates distance constraints along the strand lattice.

		Each particle of a lattice node is connected to the particle with the same offset in the next node. When
		latticeCount > 1 it is also connected to the next node's particle at offset + 1 (cyclic). When
		latticeCount > 2 it is additionally connected to the one at offset - 1 (cyclic). Together these give
		criss-cross bracing. stiffnessMultipliers, when given, is indexed per node (index // latticeCount) and scales
		the stiffness of that segment. Returns an empty list when there are fewer than 3 * latticeCount particles.
		"""
		result = []

		particleCount = len( particleIds )
		if particleCount < 3 * latticeCount:
			return result

		stiffness = ClampedStiffness( stiffnessDistance, useCompliantConstraints )
		constraintType = moov.ConstraintType.PBD_Distance if not useCompliantConstraints else moov.ConstraintType.XPBD_Distance

		# The offset tables do not depend on the node, so they are built once.
		# list() is required because Python 3 range objects cannot be concatenated with '+'.
		latticeStartOffset = list( range( latticeCount ) )
		latticeEndOffsetRanges = [latticeStartOffset]
		if latticeCount > 1:
			latticeEndOffsetRanges.append( latticeStartOffset[1:] + latticeStartOffset[:1] )
		if latticeCount > 2:
			latticeEndOffsetRanges.append( latticeStartOffset[-1:] + latticeStartOffset[:-1] )

		for index in range( startIndex, particleCount - latticeCount, latticeCount ):
			rampedStiffness = stiffnessMultipliers[index // latticeCount] * stiffness if stiffnessMultipliers is not None else stiffness
			nextNodeStart = index + latticeCount
			for offsetIndex in range( latticeCount ):
				for endRange in latticeEndOffsetRanges:
					particles = [particleIds[index + latticeStartOffset[offsetIndex]], particleIds[nextNodeStart + endRange[offsetIndex]]]
					result.append( moov.ConstraintDescription( particleIds = particles, type = constraintType, stiffness = rampedStiffness ) )

		return result

	@staticmethod
	def GetStretchLimitingConstraintDescriptions( particleIds, latticeCount = 1, startIndex = 0, stiffnessDistance = 1.0 ):
		"""Creates long-range distance constraints along the strand to limit stretch.

		Every particle of every lattice node starting at startIndex + latticeCount is tied to the particle at
		startIndex with an XPBD_DistanceX constraint. The node starting at particleCount - latticeCount or later is
		excluded. All constraints share one stiffness vector.
		"""
		# Stiffness vector (stiffness, 0, 1.1): the 1.1 lets the strand stretch by 10% before the limit kicks in
		limitStiffness = moov.Vector3( stiffnessDistance, 0, 1.1 )
		limitType = moov.ConstraintType.XPBD_DistanceX

		descriptions = []
		nodeStarts = range( startIndex + latticeCount, len( particleIds ) - latticeCount, latticeCount )
		for nodeStart in nodeStarts:
			for target in range( nodeStart, nodeStart + latticeCount ):
				anchoredPair = [particleIds[startIndex], particleIds[target]]
				descriptions.append( moov.ConstraintDescription( particleIds = anchoredPair, type = limitType, stiffness = limitStiffness ) )

		return descriptions

	@staticmethod
	def GetNthNeighborConstraintDescriptions( particleIds, latticeCount = 1, startIndex = 0, neighborCount = 2, stiffnessDistance = 1.0, useCompliantConstraints = False ):
		"""Creates n-th neighbor distance constraints along the strand.

		Each lattice node is linked to the node neighborCount positions further along. Iteration begins
		neighborCount - 1 nodes before startIndex, with the near end clamped to startIndex. As a result, the first
		neighborCount nodes after startIndex are all tied back to the particle at startIndex, which avoids artefacts
		near the root.
		"""
		particleCount = len( particleIds )
		clamped = ClampedStiffness( stiffnessDistance, useCompliantConstraints )
		if useCompliantConstraints:
			kind = moov.ConstraintType.XPBD_Distance
		else:
			kind = moov.ConstraintType.PBD_Distance

		span = latticeCount * neighborCount
		descriptions = []
		for near in range( startIndex - ( neighborCount - 1 ) * latticeCount, particleCount - span, latticeCount ):
			pair = [particleIds[max( near, startIndex )], particleIds[near + span]]
			descriptions.append( moov.ConstraintDescription( particleIds = pair, type = kind, stiffness = clamped ) )

		return descriptions

	@staticmethod
	def GetLongRangeConstraintDescriptions( particleIds, latticeCount = 1, startIndex = 0, layerCount = 1, stiffnessDistance = 1.0, stiffnessMultiplier = 1.0, useCompliantConstraints = False ):
		"""Creates multilayer long-range distance constraints along the strand.

		Layer n (1-based) links every lattice node to the node 2^n positions further along. The stiffness of layer n
		is stiffnessDistance multiplied n times by stiffnessMultiplier, and is passed through ClampedStiffness.
		"""
		particleCount = len( particleIds )
		if useCompliantConstraints:
			kind = moov.ConstraintType.XPBD_Distance
		else:
			kind = moov.ConstraintType.PBD_Distance

		descriptions = []
		hop = 1
		layerStiffness = stiffnessDistance
		for _ in range( layerCount ):
			hop *= 2
			layerStiffness *= stiffnessMultiplier
			clamped = ClampedStiffness( layerStiffness, useCompliantConstraints )
			span = latticeCount * hop
			for near in range( startIndex, particleCount - span, latticeCount ):
				pair = [particleIds[near], particleIds[near + span]]
				descriptions.append( moov.ConstraintDescription( particleIds = pair, type = kind, stiffness = clamped ) )

		return descriptions

	@staticmethod
	def GetPropagatedRootConstraintDescriptions( propagatedStrand, sourceStrand, modelType, latticeCount = 1, stiffnessDistance = 1.0, stiffnessBending = 1.0, useCompliantConstraints = False ):
		"""Creates constraints that attach the root of a propagated strand to its source strand.

		The attachment point on the source strand is given by propagatedStrand.propagatedData.sourcePointIndex1 and
		sourcePointIndex2. The first two vertices of the propagated strand (root1, root2) are connected to the
		particles of those two source vertices (source1, source2), separately for every lattice offset.

		- ModelType.DistanceOnly / DistanceBending: distance constraints source1-root1 and root1-root2 use
		  stiffnessDistance. A distance constraint source1-root2 uses stiffnessBending and stands in for a bending
		  constraint. When the two source particles differ, source2-root1 (stiffnessDistance) and source2-root2
		  (stiffnessBending) are added as well.
		- ModelType.Cosserat / CosseratDistance: two elaston stretch-shear constraints and one elaston bend-twist
		  constraint, all with stiffnessBending. CosseratDistance additionally adds source-root1 distance constraints
		  with stiffnessDistance.
		- Any other model type produces no constraints.

		Returns a list of constraint descriptions.
		"""
		result = []
		sourceIndex1 = propagatedStrand.propagatedData.sourcePointIndex1
		sourceIndex2 = propagatedStrand.propagatedData.sourcePointIndex2
		sourceParticleIds1 = sourceStrand.GetVertexParticleIds( sourceIndex1 )
		sourceParticleIds2 = sourceStrand.GetVertexParticleIds( sourceIndex2 )
		# NOTE(review): sourceParticleIds0 is not used below; the lookup may be removable (confirm GetVertexParticleIds has no needed side effects)
		sourceParticleIds0 = sourceStrand.GetVertexParticleIds( sourceIndex1 - 1 ) if sourceIndex1 > 0 else sourceParticleIds1
		rootParticleIds1 = propagatedStrand.GetVertexParticleIds( 0 )
		rootParticleIds2 = propagatedStrand.GetVertexParticleIds( 1 )
		# Stale debug output; the names it references are not defined in this method
		#print( 'strand: {0}, source {1} {2}, particle size {3}'.format( strandFirstIndex, sourceFirstIndex1, sourceFirstIndex2, len( self.particleData ) ) )
		for latticeIndex in range( 0, latticeCount ):
			rootId1 = rootParticleIds1[latticeIndex]
			rootId2 = rootParticleIds2[latticeIndex]
			sourceId1 = sourceParticleIds1[latticeIndex]
			sourceId2 = sourceParticleIds2[latticeIndex]
			#print( rootId1, rootId2, sourceId1, sourceId2 )

			if modelType == ModelType.DistanceOnly or modelType == ModelType.DistanceBending:
				result.append( DistanceConstraint( sourceId1, rootId1, stiffnessDistance, useCompliantConstraints ) )
				result.append( DistanceConstraint( rootId1, rootId2, stiffnessDistance, useCompliantConstraints ) )
				#result.append( BendingConstraint( sourceId1, rootId1, rootId2, stiffnessBending ) )
				# Use distance constraints for bending to avoid problems with coinciding particle positions
				result.append( DistanceConstraint( sourceId1, rootId2, stiffnessBending, useCompliantConstraints ) )

				# The attachment may lie on a single source vertex (both ids equal); only add the second set when it does not
				if sourceId1 != sourceId2:
					result.append( DistanceConstraint( sourceId2, rootId1, stiffnessDistance, useCompliantConstraints ) )
					#result.append( BendingConstraint( sourceId2, rootId1, rootId2, stiffnessBending ) )
					# Use distance constraints for bending to avoid problems with coinciding particle positions
					result.append( DistanceConstraint( sourceId2, rootId2, stiffnessBending, useCompliantConstraints ) )

			elif modelType == ModelType.Cosserat or modelType == ModelType.CosseratDistance:
				# NOTE(review): the stretch-shear constraints below use stiffnessBending rather than stiffnessDistance — confirm this is intended
				result.append( ElastonStretchShearConstraint( [sourceId1, rootId1, sourceId2], stiffnessBending, useCompliantConstraints ) )
				result.append( ElastonStretchShearConstraint( [rootId1, sourceId2, sourceId2], stiffnessBending, useCompliantConstraints ) )
				result.append( ElastonBendTwistConstraint( [sourceId2, rootId2], stiffnessBending, useCompliantConstraints ) )
				# Enable this for additional bending stiffness at the roots
				#if sourceId1 != sourceId2 and sourceIndex1 != 0:
				#	result.append( ElastonBendTwistConstraint( [sourceId1, rootId2], stiffnessBending, useCompliantConstraints ) )
				if modelType == ModelType.CosseratDistance:
					result.append( DistanceConstraint( sourceId1, rootId1, stiffnessDistance, useCompliantConstraints ) )
					if sourceId1 != sourceId2:
						result.append( DistanceConstraint( sourceId2, rootId1, stiffnessDistance, useCompliantConstraints ) )


		return result


	@staticmethod
	def GetRootHolderConstraintDescriptions( particleIds, latticeCount = 1, lastFixedVertexIndex = 0, stiffnessRoot = 1.0, useCompliantConstraints = False ):
		"""Creates distance constraints for the root holder.

		The root holder particles occupy the last lattice node of particleIds. Each fixed-node particle (starting at
		lastFixedVertexIndex) is connected to the root holder. Each root holder particle is connected to the node that
		follows the fixed one. Both use the same straight / cyclically shifted offset pattern as the longitudinal
		lattice constraints. Returns an empty list when particleCount - lastFixedVertexIndex < 3 * latticeCount.
		"""
		result = []
		particleCount = len( particleIds )
		if particleCount - lastFixedVertexIndex < 3 * latticeCount:
			return result

		stiffness = ClampedStiffness( stiffnessRoot, useCompliantConstraints )
		constraintType = moov.ConstraintType.PBD_Distance if not useCompliantConstraints else moov.ConstraintType.XPBD_Distance

		# The root holder particles are stored as the last lattice node
		rootHolderStartIndex = particleCount - latticeCount
		# NOTE(review): lastFixedVertexIndex is added directly to particle offsets, i.e. it is used as a particle index;
		# for latticeCount > 1 confirm callers pass a particle index rather than a vertex index.
		nextNodeStartIndex = lastFixedVertexIndex + latticeCount

		# list() is required because Python 3 range objects cannot be concatenated with '+'
		latticeStartOffset = list( range( latticeCount ) )
		latticeEndOffsetRanges = [latticeStartOffset]
		if latticeCount > 1:
			latticeEndOffsetRanges.append( latticeStartOffset[1:] + latticeStartOffset[:1] )
		if latticeCount > 2:
			latticeEndOffsetRanges.append( latticeStartOffset[-1:] + latticeStartOffset[:-1] )

		for offsetIndex, startOffset in enumerate( latticeStartOffset ):
			for endRange in latticeEndOffsetRanges:
				# Fixed node -> root holder
				particles = [particleIds[lastFixedVertexIndex + startOffset], particleIds[rootHolderStartIndex + endRange[offsetIndex]]]
				result.append( moov.ConstraintDescription( particleIds = particles, type = constraintType, stiffness = stiffness ) )
				# Root holder -> node following the fixed node
				particles = [particleIds[rootHolderStartIndex + startOffset], particleIds[nextNodeStartIndex + endRange[offsetIndex]]]
				result.append( moov.ConstraintDescription( particleIds = particles, type = constraintType, stiffness = stiffness ) )

		return result

	@staticmethod
	def ComputeInitialOrientationsFromPositions( positions, latticeCount = 1, initialDirection = None, addUnityAsFirstElement = True ):
		"""Computes one orientation per strand segment, independently for each lattice offset.

		Each chain starts with the identity rotation and the vector initialDirection (+Z when None). For every
		following segment, the rotation from Math.Quaternion.FromTwoVectors(previous segment vector, current
		segment vector) is composed with the previous orientation. When addUnityAsFirstElement is True, latticeCount
		identity quaternions are prepended. Returns a list of moov.Quaternion.
		"""
		if addUnityAsFirstElement:
			orientations = [moov.Quaternion( 1, [0, 0, 0] )] * latticeCount
		else:
			orientations = []

		if initialDirection is None:
			startVector = Math.Vector3f( 0, 0, 1 )
		else:
			startVector = Math.Vector3f( initialDirection )
		previousVectors = [startVector] * latticeCount
		previousRotations = [Math.Quaternion( 1, 0, 0, 0 )] * latticeCount

		for segmentIndex in range( 1, len( positions ) // latticeCount ):
			nodeStart = segmentIndex * latticeCount
			for chain in range( latticeCount ):
				current = nodeStart + chain
				segmentVector = Math.Vector3f( positions[current] - positions[current - latticeCount] )
				rotation = Math.Quaternion.FromTwoVectors( previousVectors[chain], segmentVector ) * previousRotations[chain]
				previousVectors[chain] = segmentVector
				previousRotations[chain] = rotation
				orientations.append( moov.Quaternion( rotation.w, [rotation.x, rotation.y, rotation.z] ) )

		return orientations

	@staticmethod
	def ComputeInitialOrientationsFromPositions_Moov( positions, latticeCount = 1, initialDirection = None, addUnityAsFirstElement = True ):
		"""Variant of ComputeInitialOrientationsFromPositions that delegates to moov.CosseratOrientationsFromPositions.

		Lattices (latticeCount > 1) fall back to the Python implementation. When initialDirection is given, it is
		prepended to positions and the Moov function is told so through its second argument.
		"""
		if latticeCount > 1:
			return StrandModel.ComputeInitialOrientationsFromPositions( positions, latticeCount, initialDirection, addUnityAsFirstElement )

		if addUnityAsFirstElement:
			orientations = [moov.Quaternion( 1, [0, 0, 0] )] * latticeCount
		else:
			orientations = []
		hasInitialDirection = initialDirection is not None
		chainPositions = [initialDirection] + positions if hasInitialDirection else positions
		orientations.extend( moov.CosseratOrientationsFromPositions( chainPositions, hasInitialDirection ) )
		return orientations

	@staticmethod
	def ComputeOrientationsFromPositions( positions, oldOrientations, latticeCount = 1 ):
		"""Updates segment orientations so that their local +Z axis follows the current segment vectors.

		For every segment and lattice offset, the old orientation's rotated +Z axis is turned onto the vector to the
		next node using Math.Quaternion.FromTwoVectors. oldOrientations is indexed like positions. Returns a list of
		moov.Quaternion with one entry per segment and lattice offset.
		"""
		updated = []
		nodeCount = len( positions ) // latticeCount

		for segmentIndex in range( nodeCount - 1 ):
			rowStart = segmentIndex * latticeCount
			for particleIndex in range( rowStart, rowStart + latticeCount ):
				previous = Math.Quaternion( oldOrientations[particleIndex] )
				previousAxis = previous.ToMatrix() * Math.Vector3f( 0, 0, 1 )
				currentVector = Math.Vector3f( positions[particleIndex + latticeCount] - positions[particleIndex] )
				rotation = Math.Quaternion.FromTwoVectors( previousAxis, currentVector ) * previous
				updated.append( moov.Quaternion( rotation.w, [rotation.x, rotation.y, rotation.z] ) )

		return updated

	@staticmethod
	def ComputeOrientationsFromPositions_Moov( positions, oldOrientations, latticeCount = 1 ):
		"""Returns latticeCount identity orientations followed by updated segment orientations.

		The first latticeCount entries of oldOrientations are skipped. Lattices use the Python implementation;
		single-particle nodes use moov.CosseratRootOrientations.
		"""
		# Identity orientation(s) for the initial lattice node
		result = [moov.Quaternion( 1, [0, 0, 0] )] * latticeCount
		remainingOrientations = oldOrientations[latticeCount:]
		if latticeCount <= 1:
			result.extend( moov.CosseratRootOrientations( positions, remainingOrientations ) )
		else:
			result.extend( StrandModel.ComputeOrientationsFromPositions( positions, remainingOrientations, latticeCount ) )
		return result


class HairModel:
	def __init__( self, solver ):
		"""Creates an empty hair model whose simulation objects will live in the given solver."""
		self.solver = solver
		self.params = ModelParameters()
		self.hair = None
		self.strandGroupTester = None
		self.strands = []

		# Strand index bookkeeping (rebuilt by ValidateStrandIndexMaps)
		self.sortedStrandsList = None          # self.strands indices in creation/update order
		self.sortedFixedRootStrandsList = []   # non-propagated strands only
		self.strandToHairIndex = None          # self.strands index -> hair strand index
		self.hairToStrandIndex = None          # hair strand index -> self.strands index

		# Solver particle sets
		self.particles = None
		self.dynamicParticles = None
		self.rootParticles = None
		self.meshParticles = None
		self.meshParticleIds = []

		# Solver constraint sets
		self.groupHolderConstraints = None
		self.attachmentConstraintsList = []


	def SetSolverParameters( self, positionIterCount = None, velocityIterCount = None, contactIterCount = None, maxSpeed = None, minSpeed = None, collisionTolerance = None,
			contactStiffnessParticleRigidBody = None, contactStiffnessRigidBody = None,
			pbdParticleFrictionCoefficient = None, pbdParticleRestitutionCoefficient = None ):
		"""Overrides the given solver parameters; arguments left as None keep their current solver values."""
		# Ordered pairs (not a dict) so the assignment order is the same on every Python version
		overrides = (
			( 'positionIterCount', positionIterCount ),
			( 'velocityIterCount', velocityIterCount ),
			( 'contactIterCount', contactIterCount ),
			( 'maxSpeed', maxSpeed ),
			( 'minSpeed', minSpeed ),
			( 'collisionTolerance', collisionTolerance ),
			( 'contactStiffnessParticleRigidBody', contactStiffnessParticleRigidBody ),
			( 'contactStiffnessRigidBody', contactStiffnessRigidBody ),
			( 'pbdParticleFrictionCoefficient', pbdParticleFrictionCoefficient ),
			( 'pbdParticleRestitutionCoefficient', pbdParticleRestitutionCoefficient ),
		)
		solverParams = self.solver.GetParameters()
		for attributeName, value in overrides:
			if value is not None:
				setattr( solverParams, attributeName, value )
		self.solver.SetParameters( solverParams )

	def SetHair( self, hair ):
		self.hair = hair
		if self.strandGroupTester is not None:
			self.strandGroupTester.Set( hair.GetHair() )

	def SetStrandGroupSet( self, strandGroupSet ):
		"""Restricts the model to the strand groups in strandGroupSet, creating the group tester on first use."""
		tester = self.strandGroupTester
		if tester is None:
			tester = ox.StrandGroupApplicationTester()
			self.strandGroupTester = tester
		tester.SetAllowedGroupSet( strandGroupSet )

	def InitializeStrands( self ):
		"""Builds a StrandModel for every hair strand accepted by the strand group filter, then rebuilds the index maps."""
		hairStrandCount = self.hair.GetStrandCount()
		self.strands = []
		for hairStrandIndex in range( hairStrandCount ):
			tester = self.strandGroupTester
			if tester is None or tester.IsOperatorApplicableToStrand( hairStrandIndex ):
				self.strands.append( StrandModel( self.hair, hairStrandIndex, self.params ) )
		self.ValidateStrandIndexMaps()

	def ValidateStrandIndexMaps( self ):
		"""Rebuilds the mappings between self.strands indices and hair strand indices.

		Fills:
		- strandToHairIndex: list mapping a self.strands index to its hair strand index.
		- hairToStrandIndex: dict mapping a hair strand index to its self.strands index.
		- sortedFixedRootStrandsList: non-propagated strands ordered by strand id.
		- sortedStrandsList: all strands ordered by propagation depth (see GetSortedStrandListByPropagationDepth).

		Assumes self.strands was built with the same strand group filter, as InitializeStrands does. When two strands
		share an id, a warning is printed and only the later one stays in the id-sorted lists.
		"""
		strandCount = self.hair.GetStrandCount()
		self.sortedFixedRootStrandsList = []
		self.sortedStrandsList = []
		self.strandToHairIndex = []
		self.hairToStrandIndex = {}
		strandIdToIndexMap = {}
		fixedRootStrandIdToIndexMap = {}
		selfStrandIndex = 0
		for strandIndex in range( strandCount ):
			if self.strandGroupTester is not None and not self.strandGroupTester.IsOperatorApplicableToStrand( strandIndex ):
				continue
			strandId = self.hair.GetHair().GetStrandId( strandIndex )
			if strandId in strandIdToIndexMap:
				# The map holds self.strands indices; report the hair strand index of both occurrences
				previousHairIndex = self.strandToHairIndex[strandIdToIndexMap[strandId]]
				print( "Warning: duplicated strand ids found: {0} for strand index {1} {2}".format( strandId, previousHairIndex, strandIndex) )
			strandIdToIndexMap[strandId] = selfStrandIndex
			# Non-propagated strands are tracked separately for root updates
			if not self.strands[selfStrandIndex].IsPropagated():
				fixedRootStrandIdToIndexMap[strandId] = selfStrandIndex
			self.strandToHairIndex.append( strandIndex )
			self.hairToStrandIndex[strandIndex] = selfStrandIndex
			selfStrandIndex += 1

		# The id-sorted list is an input to GetSortedStrandListByPropagationDepth, which then replaces it
		self.sortedStrandsList = [strandIdToIndexMap[strandId] for strandId in sorted( strandIdToIndexMap )]
		self.sortedFixedRootStrandsList = [fixedRootStrandIdToIndexMap[strandId] for strandId in sorted( fixedRootStrandIdToIndexMap )]
		self.sortedStrandsList = self.GetSortedStrandListByPropagationDepth()

	def GetSortedStrandListByPropagationDepth( self ):
		# Hair updates have to be done following the propagation levels to ensure strand transforms are correctly updated by Ornatrix (issue #2984).
		# Easiest way to deal with this is to create and update particles using the sorted strands list provided by this function.
		result = []
		result.extend( self.sortedFixedRootStrandsList )
		areAllStrandsAdded = True
		maxPropagationDepth = 10
		propagationDepth = 1
		remainingStrandsList = self.sortedStrandsList
		while propagationDepth < maxPropagationDepth:
			newRemainingStrandsList = [ index for index in remainingStrandsList if index not in result ]
			for strandIndex in newRemainingStrandsList:
				propagatedData = self.strands[strandIndex].propagatedData
				if propagatedData.isPropagated and propagatedData.sourceStrand in result:
					result.append( strandIndex )
			propagationDepth += 1
			remainingStrandsList = newRemainingStrandsList
		if len( result ) != len( self.sortedStrandsList ):
			print( "Could not sort strands; too many propagation levels" )
			return self.sortedStrandsList
		return result

	def ClearStrands( self ):
		self.sortedFixedRootStrandsList = []
		# Clear all constraints first for faster particle deletion (Moov checks if a particle is constrained before removing it)
		for strand in self.strands:
			strand.ClearConstraints( self.solver )
		for strand in self.strands:
			strand.ClearParticles( self.solver )
		self.strands = []
		self.sortedStrandsList = None
		self.particles = None
		self.dynamicParticles = None
		self.rootParticles = None

	def ClearSolverObjects( self ):
		"""Removes the solver objects created for this hair model and resets the related bookkeeping.

		Clears the whole solver, releases the attachment, removes the group holder constraints, clears strand
		constraints and particles (ClearStrands) and forgets the mesh particles.
		"""
		# TODO: For a shared solver containing several hair objects, solver.Clear() should not be called
		# Remove this call when set deletion is reasonably fast
		self.solver.Clear()
		self.ReleaseAttachment()
		# NOTE(review): this removal runs after solver.Clear(); presumably it only releases the handle — confirm
		if self.groupHolderConstraints is not None:
			self.solver.RemoveConstraints( self.groupHolderConstraints )
			self.groupHolderConstraints = None
		self.ClearStrands()
		self.meshParticleIds = []
		self.meshParticles = None


	def CreateHairModel( self ):
		"""Creates strand particles and constraints in the solver."""
		if self.hair is None:
			raise RuntimeError( 'Hair model not initialized.' )

		strandCount = len( self.strands )

		allParticleSets = []
		constraintCount = 0

		# Make sure the initial random displacements of strand tips are repeatable
		random.seed( 1 )

		# State captures rely on solver ids; strand-id dependent ordering ensures the same solver ids are created each time for the same hair object.
		for strandIndex in self.sortedStrandsList:
			strand = self.strands[strandIndex]
			strand.CreateSolverObjects( self.solver, self.hair, self.strandToHairIndex[strandIndex] )
			allParticleSets.append( strand.particleSet )
			constraintCount += len( strand.constraintSet )

		# Create particle sets
		self.particles = self.solver.JoinParticleSets( allParticleSets )
		self.dynamicParticles = self.solver.SelectParticles( self.particles, lambda pd: pd.mass != 0, moov.ParticleInformation.Mass )
		self.rootParticles = self.solver.SelectParticles( self.particles, lambda pd: pd.mass == 0, moov.ParticleInformation.Mass )

		# Connect propagated strands
		for strandIndex in self.sortedStrandsList:
			if self.strands[strandIndex].IsPropagated():
				self.ConnectPropagatedStrand( strandIndex )
				constraintCount += len( self.strands[strandIndex].propagatedRootConstraintSet )

		self.ResetMeshes()
		self.ResetGroupHolder()
		if self.groupHolderConstraints is not None:
			constraintCount += len( self.groupHolderConstraints )

		print( "Hair model created: {0} strands, {1} particles, {2} constraints, {3} meshes".format( strandCount, len( self.particles ), constraintCount, len( self.meshParticles ) ) )

	def ResetMeshes( self ):
		"""(Re)creates one collision-mesh particle per hair polygon mesh, removing previously created ones first."""
		previous = self.meshParticles
		if previous is not None:
			self.solver.RemoveParticles( previous )
			self.meshParticleIds = []
		descriptions = self.GetMeshParticleDescriptions()
		self.meshParticles = self.solver.CreateParticles( descriptions, moov.ParticleInformation.All, self.meshParticleIds )

	def ResetGroupHolder( self ):
		# Create group holders
		if self.groupHolderConstraints is not None:
			self.solver.RemoveConstraints( self.groupHolderConstraints )
			self.groupHolderConstraints = None
		if self.params.useGroupHolder:
			groupHolderDescriptions = self.GetGroupHolderConstraintDescriptions()
			self.groupHolderConstraints = self.solver.CreateConstraints( groupHolderDescriptions, moov.ConstraintInformation.All )

	def ConnectPropagatedStrand( self, strandIndex ):
		"""If a strand is propagated, creates constraints to connect it with the source strand.

		strandIndex is a self.strands index. The source strand is found through hairToStrandIndex, since
		propagatedData.sourceStrand is looked up there. The resulting constraint set is stored on the propagated
		strand as propagatedRootConstraintSet. Non-propagated strands are left untouched.
		"""
		strand = self.strands[strandIndex]
		if not strand.IsPropagated():
			return

		source = self.strands[self.hairToStrandIndex[strand.propagatedData.sourceStrand]]
		params = self.params
		descriptions = StrandModel.GetPropagatedRootConstraintDescriptions( strand, source, params.modelType, params.latticeCount, params.stretchingStiffness, params.bendingStiffness, params.useCompliantConstraints )
		strand.propagatedRootConstraintSet = self.solver.CreateConstraints( descriptions, moov.ConstraintInformation.All )


	def GetMeshParticleDescriptions( self ):
		"""Creates particle descriptions for meshes in the simulator.

		Returns one mass-0 moov.ParticleType.CollisionMesh description per mesh from self.hair.GetPolygonMeshes().
		The first mesh gets params.baseMeshCollisionGroup and all others params.collisionMeshesCollisionGroup.
		Friction and restitution come from the model parameters.
		"""
		meshParticleData = []
		for meshIndex, mesh in enumerate( self.hair.GetPolygonMeshes() ):
			pd = moov.ParticleDescription( x = mesh.position, v = [0, 0, 0], rotation = moov.Quaternion( mesh.rotationQReal, mesh.rotationQImag ), mass = 0, type = moov.ParticleType.CollisionMesh )
			pd.collisionGroup = self.params.baseMeshCollisionGroup if meshIndex == 0 else self.params.collisionMeshesCollisionGroup
			pd.polygonMesh = mesh.handle
			pd.frictionCoefficient = self.params.meshFrictionCoefficient
			pd.restitutionCoefficient = self.params.meshRestitutionCoefficient
			meshParticleData.append( pd )
		return meshParticleData

	def EnableCollisions( self, enableParticleBaseMeshCollisions = None, enableParticleMeshCollisions = None, enableCapsuleCollisions = None ):
		if enableParticleBaseMeshCollisions is not None:
			if enableParticleBaseMeshCollisions:
				self.solver.AddCollidingGroupPair( self.params.particlesCollisionGroup, self.params.baseMeshCollisionGroup )
			else:
				self.solver.RemoveCollidingGroupPair( self.params.particlesCollisionGroup, self.params.baseMeshCollisionGroup )
		if enableParticleMeshCollisions is not None:
			if enableParticleMeshCollisions:
				self.solver.AddCollidingGroupPair( self.params.particlesCollisionGroup, self.params.collisionMeshesCollisionGroup )
			else:
				self.solver.RemoveCollidingGroupPair( self.params.particlesCollisionGroup, self.params.collisionMeshesCollisionGroup )
		if enableCapsuleCollisions is not None:
			if enableCapsuleCollisions:
				self.solver.AddCollidingGroupPair( self.params.capsuleCollisionGroup, self.params.capsuleCollisionGroup )
			else:
				self.solver.RemoveCollidingGroupPair( self.params.capsuleCollisionGroup, self.params.capsuleCollisionGroup )
			for strand in self.strands:
				strand.ResetCapsules( self.solver, create = enableCapsuleCollisions )

	def GetExternalForces( self, moovParticleSet, step, forceMultiplier = 1, time = 0 ):
		"""Returns a list of moov.Vector3 forces for all particles in the set.

		Particle positions and velocities are read from the solver and passed, as ox.Vector3, to
		self.hair.GetExternalForces together with step and time. Each returned force is scaled by forceMultiplier.
		"""
		requested = moov.ParticleInformation( moov.ParticleInformation.Position | moov.ParticleInformation.Velocity )
		particleDescriptions = self.solver.GetParticleInformation( moovParticleSet, requested )
		positions, velocities = [], []
		for description in particleDescriptions:
			position, velocity = description.x, description.v
			positions.append( ox.Vector3( position[0], position[1], position[2] ) )
			velocities.append( ox.Vector3( velocity[0], velocity[1], velocity[2] ) )
		hostForces = self.hair.GetExternalForces( positions, velocities, step, time )
		scale = float( forceMultiplier )
		return [moov.Vector3( force[0], force[1], force[2] ) * scale for force in hostForces]

	def UpdateHair( self, outputHair ):
		"""Writes simulated particle positions back into outputHair.

		Dynamic particle positions are consumed in self.sortedStrandsList order, advancing by params.latticeCount
		particles per strand vertex. Only the first particle of each lattice node is used as the vertex position.
		Propagated strands are written from vertex 0, other strands from params.lastFixedVertexIndex + 1.
		"""
		particleData = self.solver.GetParticleInformation( self.dynamicParticles, moov.ParticleInformation( moov.ParticleInformation.Position ) )
		cursor = 0
		for strandIndex in self.sortedStrandsList:
			hairIndex = self.strandToHairIndex[strandIndex]
			vertexCount = outputHair.GetStrandPointCount( hairIndex )
			if self.strands[strandIndex].IsPropagated():
				firstVertex = 0
			else:
				firstVertex = self.params.lastFixedVertexIndex + 1
			positions = []
			for _ in range( firstVertex, vertexCount ):
				positions.append( ox.Vector3( particleData[cursor].x ) )
				cursor += self.params.latticeCount
			outputHair.SetStrandPointsInWorldCoordinates( hairIndex, firstVertex, vertexCount, positions )

	def GetHairParticlePositions( self, minStrandPointIndex = None, maxStrandPointIndex = None, strandIndexList = None ):
		"""Returns lattice particle positions for a range of strand vertices, ordered by particle id.

		minStrandPointIndex and maxStrandPointIndex bound the vertex range of each strand. By default the whole strand
		is used; the minimum is clamped to 0 and the maximum to the strand point count. strandIndexList selects
		self.strands indices and defaults to self.sortedStrandsList.
		"""
		particlePositions = {}
		minStrandPoint = 0 if minStrandPointIndex is None else max( 0, minStrandPointIndex )
		if strandIndexList is None:
			strandIndexList = self.sortedStrandsList
		for strandIndex in strandIndexList:
			strand = self.strands[strandIndex]
			strandHairIndex = self.strandToHairIndex[strandIndex]
			strandVertexCount = self.hair.GetStrandPointCount( strandHairIndex )
			maxStrandPoint = strandVertexCount if maxStrandPointIndex is None else min( strandVertexCount, maxStrandPointIndex )
			for index in range( minStrandPoint, maxStrandPoint ):
				latticePositions = strand.GetLatticeNodePositions( self.hair, strandHairIndex, index )
				particleStartIndex = index * self.params.latticeCount
				for latticeIndex, position in enumerate( latticePositions ):
					particlePositions[strand.particleIds[particleStartIndex + latticeIndex]] = position
		return [particlePositions[particleId] for particleId in sorted( particlePositions )]

	def GetHairVertexPositions( self, hair = None ):
		result = []
		if hair is None:
			hair = self.hair
		for strandIndex in self.sortedStrandsList:
			result.extend( hair.GetStrandPointsInWorldCoordinates( self.strandToHairIndex[strandIndex] ) )
		return result

	def GetHairRootPositions__deprecated( self ):
		"""Deprecated: returns root particle positions of all non-propagated strands using one batched hair query.

		Superseded by GetHairRootPositions, which also clamps the vertex range to each strand's point count.
		"""
		particlePositions = []

		firstMovingVertexIndex = self.params.lastFixedVertexIndex + 2 if self.params.useRootHolder else self.params.lastFixedVertexIndex + 1
		# sortedFixedRootStrandsList holds self.strands indices; the hair API is queried with hair strand indices
		# (as in GetHairRootPositions), so map them first. This matters when a strand group filter is active.
		hairStrandIndices = [self.strandToHairIndex[strandIndex] for strandIndex in self.sortedFixedRootStrandsList]
		positions = self.hair.GetStrandsPointsInWorldCoordinates( hairStrandIndices, 0, firstMovingVertexIndex )

		for index, strandIndex in enumerate( self.sortedFixedRootStrandsList ):
			positionIndex = index * firstMovingVertexIndex
			if self.params.useRootHolder:
				# Root holder sits between the last fixed vertex and the next one
				rootHolderIndex = positionIndex + firstMovingVertexIndex - 1
				positions[rootHolderIndex] = ( 1.0 - self.params.rootHolderPosition ) * positions[rootHolderIndex - 1] + self.params.rootHolderPosition * positions[rootHolderIndex]

			if self.params.latticeCount > 1:
				for particleIndex in range( firstMovingVertexIndex ):
					latticePositions = self.strands[strandIndex].GetLatticeNodePositions( self.hair, self.strandToHairIndex[strandIndex], particleIndex, positions[positionIndex:positionIndex + firstMovingVertexIndex] )
					particlePositions.extend( latticePositions )
				continue

			for particleIndex in range( firstMovingVertexIndex ):
				particlePositions.append( moov.Vector3( positions[positionIndex + particleIndex] ) )

		return particlePositions

	def GetHairRootPositions( self ):
		"""Returns the positions of the fixed root particles of every strand in sortedFixedRootStrandsList.

		Each strand contributes its first lastFixedVertexIndex + 1 vertices (one more when params.useRootHolder is set),
		clamped to the strand's point count. With the root holder, the last of those vertices is replaced by the blend
		( 1 - rootHolderPosition ) * previous + rootHolderPosition * itself. When params.latticeCount > 1 the strand's
		lattice node positions are emitted per root vertex; otherwise one moov.Vector3 per root vertex.
		"""
		particlePositions = []

		firstMovingVertexIndex = self.params.lastFixedVertexIndex + 2 if self.params.useRootHolder else self.params.lastFixedVertexIndex + 1

		for strandIndex in self.sortedFixedRootStrandsList:
			strandHairIndex = self.strandToHairIndex[strandIndex]
			# Short strands provide fewer root vertices than requested
			firstMovingVertexIndexForStrand = max( 0, min( firstMovingVertexIndex, self.hair.GetStrandPointCount( strandHairIndex ) ) )
			strandPositions = self.hair.GetStrandPointsInWorldCoordinates( strandHairIndex )
			positions = strandPositions[0:firstMovingVertexIndexForStrand]

			rootHolderIndex = firstMovingVertexIndexForStrand - 1
			# The root holder blends with its predecessor; without one (0 or 1 root vertices) a negative index would wrap
			# around or raise IndexError on an empty strand, so the blend is skipped.
			if self.params.useRootHolder and rootHolderIndex >= 1:
				positions[rootHolderIndex] = ( 1.0 - self.params.rootHolderPosition ) * positions[rootHolderIndex - 1] + self.params.rootHolderPosition * positions[rootHolderIndex]

			if self.params.latticeCount > 1:
				strand = self.strands[strandIndex]
				for particleIndex in range( firstMovingVertexIndexForStrand ):
					particlePositions.extend( strand.GetLatticeNodePositions( self.hair, strandHairIndex, particleIndex, positions ) )
				continue

			for particleIndex in range( firstMovingVertexIndexForStrand ):
				particlePositions.append( moov.Vector3( positions[particleIndex] ) )

		return particlePositions


	def GetHairRootOrientations( self, newRootPositions, oldRootOrientations = None ):
		"""Computes orientations for the root particles of Cosserat-type models.

		Returns None unless params.modelType is ModelType.Cosserat or ModelType.CosseratDistance and
		params.lastFixedVertexIndex >= 1. Otherwise newRootPositions (and oldRootOrientations, if given) are consumed in
		per-strand chunks of latticeCount * root-vertex-count entries, in sortedFixedRootStrandsList order. This is
		presumably the layout produced by GetHairRootPositions; confirm against callers. Without oldRootOrientations the
		initial orientations are computed from positions only.
		"""
		params = self.params
		if params.modelType != ModelType.Cosserat and params.modelType != ModelType.CosseratDistance or params.lastFixedVertexIndex < 1:
			return None

		rootVertexCount = params.lastFixedVertexIndex + ( 2 if params.useRootHolder else 1 )

		orientations = []
		offset = 0
		for strandIndex in self.sortedFixedRootStrandsList:
			strandPointCount = self.hair.GetStrandPointCount( self.strandToHairIndex[strandIndex] )
			chunkLength = params.latticeCount * max( 0, min( rootVertexCount, strandPointCount ) )
			end = offset + chunkLength
			strandPositions = newRootPositions[offset:end]
			if oldRootOrientations is None:
				strandOrientations = StrandModel.ComputeInitialOrientationsFromPositions_Moov( strandPositions, params.latticeCount )
			else:
				strandOrientations = StrandModel.ComputeOrientationsFromPositions_Moov( strandPositions, oldRootOrientations[offset:end], params.latticeCount )
			orientations.extend( strandOrientations )
			offset = end

		return orientations


	def GetHairRootOrientationsFromStrandTransform( self ):
		"""Builds root orientations from each strand's stored orientation quaternion.

		Returns None unless params.modelType is ModelType.Cosserat or ModelType.CosseratDistance and
		params.lastFixedVertexIndex >= 1. For every strand in sortedFixedRootStrandsList, it emits latticeCount identity
		quaternions followed by latticeCount copies of the strand's quaternion. The hair's 4-element list is split as
		[0] / [1:4], presumably scalar part first; confirm.
		"""
		params = self.params
		if params.modelType != ModelType.Cosserat and params.modelType != ModelType.CosseratDistance or params.lastFixedVertexIndex < 1:
			return None

		# NOTE(review): this always yields 2 * latticeCount orientations per strand, whereas GetHairRootOrientations
		# yields latticeCount per root vertex -- confirm both agree for the configurations in use.
		# NOTE(review): list multiplication repeats one Quaternion instance; fine only if the solver never mutates them in place.
		result = []
		for strandIndex in self.sortedFixedRootStrandsList:
			identity = moov.Quaternion( 1, [0, 0, 0] )
			quaternionComponents = self.hair.GetStrandOrientationQuaternion( self.strandToHairIndex[strandIndex] )
			strandOrientation = moov.Quaternion( quaternionComponents[0], quaternionComponents[1:4] )
			result += [identity] * params.latticeCount
			result += [strandOrientation] * params.latticeCount

		return result


	def GetHairInitialPositionsAndRampMultipliers( self, rampFunction = None, channel = None ):
		"""Returns the initial positions of all dynamic particles together with their per-particle multipliers.

		For every strand in sortedStrandsList, each dynamic vertex yields params.latticeCount particles located at the
		vertex's world position plus a lattice offset (currently always zero). A vertex is dynamic if its strand is
		propagated, or if its index is past params.lastFixedVertexIndex.

		Parameters:
			rampFunction -- optional object whose Evaluate( t ) is sampled at the vertex's normalized position
			                along the strand (0 at the root, 1 at the tip)
			channel -- optional channel handed to StrandChannel; when valid, its per-vertex value also scales the multiplier

		Returns:
			( particlePositions, rampMultipliers ). rampMultipliers holds one value per particle, or is empty when
			there is neither a ramp nor a valid channel.
		"""
		strandChannel = StrandChannel( self.hair.GetHair(), channel )
		hasMultipliers = rampFunction is not None or strandChannel.isValidChannel
		# TODO: Obtain actual lattice offsets
		# Loop-invariant, so built once instead of per vertex
		latticeOffsets = [moov.Vector3( 0, 0, 0 )] * self.params.latticeCount
		particlePositions = []
		rampMultipliers = []
		for strandIndex in self.sortedStrandsList:
			strandHairIndex = self.strandToHairIndex[strandIndex]
			strandPointPositions = self.hair.GetStrandPointsInWorldCoordinates( strandHairIndex )
			minStrandPoint = 0 if self.strands[strandIndex].IsPropagated() else self.params.lastFixedVertexIndex + 1
			strandVertexCount = len( strandPointPositions )
			lastVertexIndex = strandVertexCount - 1
			for index in range( minStrandPoint, strandVertexCount ):
				vertexPosition = strandPointPositions[index]
				# accommodate ox.HostVector3 position lists
				moovVertexPosition = moov.Vector3( vertexPosition[0], vertexPosition[1], vertexPosition[2] )

				multiplier = 1.0
				if rampFunction is not None:
					# A single-vertex strand has no extent: evaluate it at the root instead of dividing by zero
					positionAlongStrand = index / lastVertexIndex if lastVertexIndex > 0 else 0.0
					multiplier *= rampFunction.Evaluate( positionAlongStrand )
				if strandChannel.isValidChannel:
					multiplier *= strandChannel.GetValue( strandHairIndex, index )

				for latticeOffset in latticeOffsets:
					particlePositions.append( moovVertexPosition + latticeOffset )
					if hasMultipliers:
						rampMultipliers.append( multiplier )
		return particlePositions, rampMultipliers


	def GetValuesTimesRootToTipMultipliers( self, valueList, rampFunction = None, channel = None ):
		"""Scales per-particle values by a root-to-tip ramp and/or a per-vertex channel.

		valueList is indexed like the dynamic particles. For every strand in sortedStrandsList, each dynamic vertex
		contributes params.latticeCount consecutive entries. A vertex is dynamic if its strand is propagated, or if its
		index is past params.lastFixedVertexIndex. Each entry is multiplied by rampFunction.Evaluate( t ), where t is the
		vertex's normalized position along the strand (0 at the root, 1 at the tip). It is also multiplied by the
		vertex's value in the given channel.

		Returns valueList itself when it is None or when there is neither a ramp nor a valid channel; otherwise a new list.
		"""
		strandChannel = StrandChannel( self.hair.GetHair(), channel )
		if ( rampFunction is None and not strandChannel.isValidChannel ) or valueList is None:
			return valueList
		result = []
		valueIndex = 0
		for strandIndex in self.sortedStrandsList:
			strandHairIndex = self.strandToHairIndex[strandIndex]
			minStrandPoint = 0 if self.strands[strandIndex].IsPropagated() else self.params.lastFixedVertexIndex + 1
			strandVertexCount = self.hair.GetStrandPointCount( strandHairIndex )
			lastVertexIndex = strandVertexCount - 1
			for index in range( minStrandPoint, strandVertexCount ):
				multiplier = 1.0
				if rampFunction is not None:
					# A single-vertex strand has no extent: evaluate it at the root instead of dividing by zero
					positionAlongStrand = float( index ) / lastVertexIndex if lastVertexIndex > 0 else 0.0
					multiplier *= rampFunction.Evaluate( positionAlongStrand )
				if strandChannel.isValidChannel:
					multiplier *= strandChannel.GetValue( strandHairIndex, index )
				for _ in range( self.params.latticeCount ):
					result.append( valueList[valueIndex] * multiplier )
					valueIndex += 1
		return result


	# Group holder functions

	def GenerateGroupHolderParticlePairs_Random( self, strandGroups, minShellIndex, maxShellIndex, maxCountPerGroup = 1000 ):
		"""Generates particle Id pairs for the group holder selecting random particles from random strands.

		Parameters:
			strandGroups -- dict mapping a group index to the list of strand indices (keys of self.strands) in that group
			minShellIndex, maxShellIndex -- range of strand point indices to sample from. maxShellIndex is exclusive
			                                (it is the stop argument of random.randrange). Both are clamped to
			                                [0, strand point count].
			maxCountPerGroup -- number of random pairs drawn per group; duplicates collapse, so fewer may be returned

		Returns a set of ( smallerParticleId, largerParticleId ) tuples. Groups with fewer than two strands are skipped.
		Draws that hit a strand too short to reach the shell range are also skipped.
		"""
		result = set()
		lowestPoint = max( 0, minShellIndex )
		for group in strandGroups.values():
			groupCount = len( group )
			# A pair needs two distinct strands
			if groupCount < 2:
				continue
			for _ in range( maxCountPerGroup ):
				strand1GroupIndex = random.randrange( 0, groupCount )
				strand2GroupIndex = random.randrange( 0, groupCount )
				if strand2GroupIndex == strand1GroupIndex:
					# Shift to the next strand so a strand is never paired with itself
					strand2GroupIndex = ( strand2GroupIndex + 1 ) % groupCount
				strand1Index = group[strand1GroupIndex]
				strand2Index = group[strand2GroupIndex]
				# Groups hold strand indices; hair queries need the corresponding hair index
				strand1Stop = min( self.hair.GetStrandPointCount( self.strandToHairIndex[strand1Index] ), maxShellIndex )
				strand2Stop = min( self.hair.GetStrandPointCount( self.strandToHairIndex[strand2Index] ), maxShellIndex )
				# random.randrange raises ValueError on an empty range; such strands cannot contribute a point
				if strand1Stop <= lowestPoint or strand2Stop <= lowestPoint:
					continue
				strand1Point = random.randrange( lowestPoint, strand1Stop )
				strand2Point = random.randrange( lowestPoint, strand2Stop )
				# Return only the lead particle of a node
				# This works well if groupStiffness is smaller than latticeStiffness and/or there are many group holder constraints
				particleId1 = self.strands[strand1Index].GetParticleId( strand1Point, 0 )
				particleId2 = self.strands[strand2Index].GetParticleId( strand2Point, 0 )
				# Order to avoid duplicated pairs
				result.add( ( min( particleId1, particleId2 ), max( particleId1, particleId2 ) ) )
		return result

	def GetGroupHolderConstraintDescriptions( self ):
		"""Creates constraints to hold groups of strands together.

		Strands in sortedStrandsList are bucketed by the group reported by the hair. When
		params.groupHolderGenerator is 'Random', the random module is reseeded with params.groupHolderRandomSeed and
		particle pairs are sampled within each group. Each pair becomes a distance constraint: XPBD if
		params.useCompliantConstraints is set, PBD otherwise. Any other generator yields an empty list.
		"""
		params = self.params

		strandGroups = {}
		for strandIndex in self.sortedStrandsList:
			groupIndex = self.hair.GetStrandGroup( self.strandToHairIndex[strandIndex] )
			strandGroups.setdefault( groupIndex, [] ).append( strandIndex )

		if params.groupHolderGenerator == 'Random':
			random.seed( params.groupHolderRandomSeed )
			particlePairs = self.GenerateGroupHolderParticlePairs_Random( strandGroups, params.groupHolderPosMin, params.groupHolderPosMax, params.groupHolderMaxGroupCount )
		else:
			particlePairs = []

		useCompliant = params.useCompliantConstraints
		pairStiffness = ClampedStiffness( params.groupHolderStiffness, useCompliant )
		constraintType = moov.ConstraintType.XPBD_Distance if useCompliant else moov.ConstraintType.PBD_Distance
		return [moov.ConstraintDescription( particleIds = pair, type = constraintType, stiffness = pairStiffness ) for pair in particlePairs]

	# Attachment functions

	def GenerateAttachmentParticlePairs_Random( self, pointIndicesByStrandIndex, maxConstraintCount ):
		"""Generates particle Id pairs for attachment selecting random particles from random strands.

		Parameters:
			pointIndicesByStrandIndex -- dict mapping a strand index (key of self.strands) to the candidate point
			                             indices on that strand
			maxConstraintCount -- number of random pairs drawn; duplicates collapse, so fewer may be returned

		Returns a set of ( smallerParticleId, largerParticleId ) tuples, built from the lead lattice particle of each
		chosen point. The set is empty when fewer than two strands are given.
		"""
		# Materialize the keys: dict views cannot be indexed in Python 3
		group = list( pointIndicesByStrandIndex.keys() )
		groupCount = len( group )
		result = set()
		if groupCount < 2:
			return result
		for _ in range( maxConstraintCount ):
			strand1GroupIndex = random.randrange( 0, groupCount )
			strand2GroupIndex = random.randrange( 0, groupCount )
			if strand2GroupIndex == strand1GroupIndex:
				# Shift to the next strand so a strand is never attached to itself
				strand2GroupIndex = ( strand2GroupIndex + 1 ) % groupCount
			strand1Index = group[strand1GroupIndex]
			strand1Points = pointIndicesByStrandIndex[strand1Index]
			strand1Point = strand1Points[random.randrange( 0, len( strand1Points ) )]
			strand2Index = group[strand2GroupIndex]
			strand2Points = pointIndicesByStrandIndex[strand2Index]
			strand2Point = strand2Points[random.randrange( 0, len( strand2Points ) )]
			# use only the lead particle of a lattice node
			particleId1 = self.strands[strand1Index].GetParticleId( strand1Point, 0 )
			particleId2 = self.strands[strand2Index].GetParticleId( strand2Point, 0 )
			# Order to avoid duplicated pairs
			result.add( ( min( particleId1, particleId2 ), max( particleId1, particleId2 ) ) )
		return result

	def CreateAttachment( self, dataGenerator, stiffness, constraintDensity, createOneAttachmentPerObject = False, maxConstraintCount = 10000 ):
		"""Creates constraints holding the attachment together.

		Any existing attachment is released first. For each strand, dataGenerator reports which points are inside
		its objects. Points are collected into one attachment per generator object when createOneAttachmentPerObject
		is True, otherwise into a single attachment. Each attachment that spans at least two strands receives random
		distance constraints between points of different strands.

		Parameters:
			dataGenerator -- provides GetObjectCount() and Evaluate( positions, method )
			stiffness -- constraint stiffness, passed through ClampedStiffness
			constraintDensity -- each attachment gets up to int( constraintDensity * insidePointCount * 5 ) constraints
			createOneAttachmentPerObject -- create a separate attachment per generator object
			maxConstraintCount -- upper bound on the constraint count of each attachment
		"""
		self.ReleaseAttachment()
		# TODO: use own random seed?
		random.seed( self.params.groupHolderRandomSeed )
		attachmentCount = dataGenerator.GetObjectCount() if createOneAttachmentPerObject else 1
		pointIndicesByStrandList = [{} for _ in range( attachmentCount )]
		evaluationMethod = ox.DataGenerationMethod.IsInsideEach if createOneAttachmentPerObject else ox.DataGenerationMethod.IsInsideAny
		for strandIndex in self.sortedStrandsList:
			particlePositions = self.hair.GetStrandPointsInWorldCoordinates( self.strandToHairIndex[strandIndex] )
			isInside = dataGenerator.Evaluate( particlePositions, evaluationMethod )
			pointCount = len( particlePositions )
			for attachmentIndex, pointIndicesByStrand in enumerate( pointIndicesByStrandList ):
				# isInside is indexed as point * attachmentCount + attachment
				pointIndices = [strandPointIndex for strandPointIndex in range( pointCount ) if isInside[strandPointIndex * attachmentCount + attachmentIndex] != 0]
				if pointIndices:
					pointIndicesByStrand[strandIndex] = pointIndices

		clampedStiffness = ClampedStiffness( stiffness, self.params.useCompliantConstraints )
		constraintType = moov.ConstraintType.XPBD_Distance if self.params.useCompliantConstraints else moov.ConstraintType.PBD_Distance
		for pointIndicesByStrand in pointIndicesByStrandList:
			# Constraints connect different strands, so at least two are needed
			if len( pointIndicesByStrand ) < 2:
				continue
			particleCount = sum( len( pointIndices ) for pointIndices in pointIndicesByStrand.values() )
			# Per-attachment limit: overwriting maxConstraintCount here would let one small attachment cap every following one
			attachmentConstraintCount = min( maxConstraintCount, int( constraintDensity * particleCount * 5 ) )
			particlePairs = self.GenerateAttachmentParticlePairs_Random( pointIndicesByStrand, attachmentConstraintCount )

			constraints = [ moov.ConstraintDescription( particleIds = particlePair, type = constraintType, stiffness = clampedStiffness ) for particlePair in particlePairs ]
			self.attachmentConstraintsList.append( self.solver.CreateConstraints( constraints, moov.ConstraintInformation.All ) )

		# sum() instead of reduce(): reduce is not a builtin in Python 3
		attachmentConstraintsCount = sum( len( attachmentConstraints ) for attachmentConstraints in self.attachmentConstraintsList )
		print( 'Created {} attachments with a total of {} constraints'.format( len( self.attachmentConstraintsList ), attachmentConstraintsCount ) )

	def ReleaseAttachment( self ):
		"""Removes every stored attachment constraint set from the solver, then clears the stored list."""
		for constraintSet in self.attachmentConstraintsList:
			self.solver.RemoveConstraints( constraintSet )
		# Only reset after all removals succeeded
		self.attachmentConstraintsList = []

	# Functions for working with guides data (only if hair has guides)

	def GetChannelIndex( self, channelName ):
		"""Finds the index of the per-vertex channel named channelName.

		Returns -1 when the hair object is unavailable; otherwise whatever the hair reports. Callers treat negative
		values as "not found".
		"""
		guides = self.hair.GetHair()
		return -1 if guides is None else guides.GetChannelIndex( ox.StrandDataType.PerVertex, channelName )

	def AddVertexDataChannel( self, channelName ):
		"""Adds a new vertex data channel if a channel with the specified name does not exist.

		Returns the index of the existing or newly created channel, or -1 when the hair object is unavailable.
		A new channel is appended after the existing per-vertex channels.
		"""
		existingIndex = self.GetChannelIndex( channelName )
		if existingIndex >= 0:
			return existingIndex

		guides = self.hair.GetHair()
		if guides is None:
			return -1

		perVertex = ox.StrandDataType.PerVertex
		newIndex = guides.GetStrandChannelCount( perVertex )
		guides.SetStrandChannelCount( perVertex, newIndex + 1 )
		guides.SetStrandChannelName( perVertex, newIndex, channelName )
		return newIndex


	def GetChannelVertexData( self, channelIndex ):
		"""Retrieves data from a channel.

		Returns one value per vertex, covering every vertex of every strand in sortedStrandsList order. Returns an
		empty list when the hair object is unavailable or channelIndex is outside the per-vertex channel range.
		"""
		guides = self.hair.GetHair()
		if guides is None or not 0 <= channelIndex < guides.GetStrandChannelCount( ox.StrandDataType.PerVertex ):
			return []

		channelValues = []
		for strandIndex in self.sortedStrandsList:
			hairIndex = self.strandToHairIndex[strandIndex]
			pointCount = self.hair.GetStrandPointCount( hairIndex )
			channelValues.extend( guides.GetStrandChannelData( channelIndex, hairIndex, pointIndex ) for pointIndex in range( pointCount ) )
		return channelValues


	def SetChannelVertexData( self, channelIndex,  vertexData ):
		"""Writes data to a channel.

		vertexData is consumed in the layout GetChannelVertexData returns: one value per vertex, for every strand in
		sortedStrandsList order. Returns True after writing. If the hair object is unavailable or channelIndex is out
		of range, it prints a message and returns False. Raises IndexError if vertexData is too short; values before
		the failing vertex are already written at that point.
		"""
		guides = self.hair.GetHair()
		if guides is None or not 0 <= channelIndex < guides.GetStrandChannelCount( ox.StrandDataType.PerVertex ):
			print( "No hair or bad channel index: hair {0}, channelIndex {1}".format( guides, channelIndex ) )
			return False

		flatIndex = 0
		for strandIndex in self.sortedStrandsList:
			hairIndex = self.strandToHairIndex[strandIndex]
			for pointIndex in range( self.hair.GetStrandPointCount( hairIndex ) ):
				guides.SetStrandChannelData( channelIndex, hairIndex, pointIndex, vertexData[flatIndex] )
				flatIndex += 1
		return True

	def InitChannelVertexData( self, channelIndex,  value ):
		"""Sets all data in the channel to value.

		Returns True if the channel exists and was written. Returns False when the hair object is unavailable or
		channelIndex is out of range.
		"""
		guides = self.hair.GetHair()
		channelIsValid = guides is not None and 0 <= channelIndex < guides.GetStrandChannelCount( ox.StrandDataType.PerVertex )
		if channelIsValid:
			# NOTE(review): the meaning of the literal 1 argument is defined by the ox API -- confirm (element count per vertex?)
			guides.SetStrandChannelAllData( ox.StrandDataType.PerVertex, channelIndex, 1, value )
		return channelIsValid

	def GetPerStrandNoiseList( self, perStrandNoise, randomSeed = None ):
		"""Returns one value drawn uniformly from [-perStrandNoise, perStrandNoise] per strand in sortedStrandsList.

		When randomSeed is given, the global random module is reseeded first, which makes the result reproducible.
		"""
		if randomSeed is not None:
			random.seed( randomSeed )
		return [random.uniform( -perStrandNoise, perStrandNoise ) for _ in self.sortedStrandsList]

	def GetPerStrandNoiseForDynamicParticleList( self, perStrandNoise, randomSeed = None ):
		"""Creates a list containing random per-strand noise values for each dynamic particle.

		One value is drawn uniformly from [-perStrandNoise, perStrandNoise] per strand in sortedStrandsList, even when
		the strand has no dynamic vertex. That value is repeated latticeCount times for every dynamic vertex of the
		strand. A vertex is dynamic if its strand is propagated, or if its index is past lastFixedVertexIndex.
		When randomSeed is given, the global random module is reseeded first.
		"""
		if randomSeed is not None:
			random.seed( randomSeed )

		noise = []
		for strandIndex in self.sortedStrandsList:
			firstDynamic = 0 if self.strands[strandIndex].IsPropagated() else self.params.lastFixedVertexIndex + 1
			pointCount = self.hair.GetStrandPointCount( self.strandToHairIndex[strandIndex] )
			strandNoise = random.uniform( -perStrandNoise, perStrandNoise )
			dynamicPointCount = max( 0, pointCount - firstDynamic )
			noise.extend( [strandNoise] * ( self.params.latticeCount * dynamicPointCount ) )

		return noise

	def DynamicParticleToVertexIndexMap( self ):
		"""Returns a list mapping indices of dynamic particles to channel data indices.

		Channel data holds one slot per vertex of every strand in sortedStrandsList order. Each dynamic vertex owns
		latticeCount consecutive particles, and all of them map to that vertex's slot. A vertex is dynamic if its
		strand is propagated, or if its index is past lastFixedVertexIndex.
		"""
		indexMap = []
		vertexOffset = 0

		for strandIndex in self.sortedStrandsList:
			firstDynamic = 0 if self.strands[strandIndex].IsPropagated() else self.params.lastFixedVertexIndex + 1
			pointCount = self.hair.GetStrandPointCount( self.strandToHairIndex[strandIndex] )
			indexMap.extend( vertexOffset + pointIndex for pointIndex in range( max( 0, firstDynamic ), pointCount ) for _ in range( self.params.latticeCount ) )
			# Every vertex of the strand occupies one channel slot, dynamic or not
			vertexOffset += max( 0, pointCount )

		return indexMap

	def VertexToDynamicParticleIndexMap( self ):
		"""Returns a list mapping channel data indices to indices of dynamic particles.

		There is one entry per vertex of every strand in sortedStrandsList order. The entry is an empty list for a
		non-dynamic vertex. For a dynamic vertex, it lists the latticeCount consecutive particle indices owned by that
		vertex. A vertex is dynamic if its strand is propagated, or if its index is past lastFixedVertexIndex.
		"""
		indexMap = []
		nextParticle = 0

		for strandIndex in self.sortedStrandsList:
			firstDynamic = 0 if self.strands[strandIndex].IsPropagated() else self.params.lastFixedVertexIndex + 1
			pointCount = self.hair.GetStrandPointCount( self.strandToHairIndex[strandIndex] )
			for pointIndex in range( pointCount ):
				if pointIndex < firstDynamic:
					indexMap.append( [] )
					continue
				particleIds = list( range( nextParticle, nextParticle + self.params.latticeCount ) )
				indexMap.append( particleIds )
				nextParticle += len( particleIds )

		return indexMap

	def ConvolvePerStrand( self, vertexData, convolutionKernel ):
		"""Computes per-strand convolution of the vertexData array.

		For each strand in sortedStrandsList, the slice of vertexData starting at the hair's first vertex index for
		that strand, with the strand's point count, is convolved independently. The results are concatenated.
		"""
		# NOTE(review): slices use the hair's own vertex layout (GetFirstVertexIndex), while GetChannelVertexData packs
		# strands in sortedStrandsList order -- confirm which layout callers pass in.
		convolved = []
		for strandIndex in self.sortedStrandsList:
			hairIndex = self.strandToHairIndex[strandIndex]
			begin = self.hair.GetFirstVertexIndex( hairIndex )
			end = begin + self.hair.GetStrandPointCount( hairIndex )
			convolved += self.ConvolveStrand( vertexData[begin:end], convolutionKernel )
		return convolved

	@staticmethod
	def ConvolveStrand( strandData, convolutionKernel ):
		kernelLength = len( convolutionKernel )
		if kernelLength % 2 != 1:
			raise ValueError( "Only odd dimensions of convolution kernel supported" )
		kernelCenter = kernelLength // 2
		strandLength = len( strandData )
		result = []

		for index in range( strandLength ):
			kFrom = max( -index, -kernelCenter )
			kTo = min( strandLength - index, kernelCenter + 1 )
			value = 0
			for k in range( -kernelCenter, kernelCenter + 1 ):
				# take strand end values if outside range to avoid convolution decrease near ends
				strandValue = strandData[index + kFrom] if k < kFrom else strandData[index + kTo - 1] if k >= kTo else strandData[index + k]
				value += convolutionKernel[kernelCenter - k] * strandValue
			result.append( value )

		return result


class VertexChannel:
	"""Class allowing the exchange of guide vertex channel data between a HairModel and a Moov simulation.

	The channel is looked up by name on every access. Read and Write go through the hair model's
	GetChannelVertexData / SetChannelVertexData, so data is exchanged for every vertex of every strand in the hair
	model's strand order. Example: VertexChannel( hairModel, 'Name' ).Read()
	"""
	def __init__( self, hairModel, channelName ):
		"""Initializes the vertex channel instance.

		Parameters:
			hairModel -- HairModel object containing the hair data
			channelName -- (string) name of the vertex channel
		"""
		self.hairModel = hairModel
		self.channelName = channelName

	def Get( self, createIfAbsent = False ):
		"""Returns the channel index, creating the channel first when createIfAbsent is True.

		Raises RuntimeError if the channel cannot be found or created.
		"""
		if createIfAbsent:
			channelIndex = self.hairModel.AddVertexDataChannel( self.channelName )
		else:
			channelIndex = self.hairModel.GetChannelIndex( self.channelName )
		if channelIndex < 0:
			raise RuntimeError( 'Could not obtain vertex data channel {0}'.format( self.channelName ) )
		return channelIndex

	def Initialize( self, value, createIfAbsent = False ):
		"""Sets every value of the channel to value. Returns the hair model's success flag."""
		return self.hairModel.InitChannelVertexData( self.Get( createIfAbsent ), value )

	def Write( self, data, createIfAbsent = False ):
		"""Writes data (one value per vertex) to the channel. Returns the hair model's success flag."""
		return self.hairModel.SetChannelVertexData( self.Get( createIfAbsent ), data )

	def Read( self ):
		"""Returns the channel's values, one per vertex."""
		return self.hairModel.GetChannelVertexData( self.Get() )
